chore(rename): move llama_stack.distribution to llama_stack.core (#2975)

We would like to rename the term `template` to `distribution`. To
prepare for that, this is a precursor.

cc @leseb
This commit is contained in:
Ashwin Bharambe 2025-07-30 23:30:53 -07:00 committed by GitHub
parent f3d5459647
commit 2665f00102
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
211 changed files with 351 additions and 348 deletions

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,131 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.core.datatypes import User
from .conditions import (
Condition,
ProtectedResource,
parse_conditions,
)
from .datatypes import (
AccessRule,
Action,
Scope,
)
def matches_resource(resource_scope: str, actual_resource: str) -> bool:
if resource_scope == actual_resource:
return True
return resource_scope.endswith("::*") and actual_resource.startswith(resource_scope[:-1])
def matches_scope(
scope: Scope,
action: Action,
resource: str,
user: str | None,
) -> bool:
if scope.resource and not matches_resource(scope.resource, resource):
return False
if scope.principal and scope.principal != user:
return False
return action in scope.actions
def as_list(obj: Any) -> list[Any]:
if isinstance(obj, list):
return obj
return [obj]
def matches_conditions(
conditions: list[Condition],
resource: ProtectedResource,
user: User,
) -> bool:
for condition in conditions:
# must match all conditions
if not condition.matches(resource, user):
return False
return True
def default_policy() -> list[AccessRule]:
# for backwards compatibility, if no rules are provided, assume
# full access subject to previous attribute matching rules
return [
AccessRule(
permit=Scope(actions=list(Action)),
when=["user in owners " + name for name in ["roles", "teams", "projects", "namespaces"]],
),
]
def is_action_allowed(
policy: list[AccessRule],
action: Action,
resource: ProtectedResource,
user: User | None,
) -> bool:
# If user is not set, assume authentication is not enabled
if not user:
return True
if not len(policy):
policy = default_policy()
qualified_resource_id = f"{resource.type}::{resource.identifier}"
for rule in policy:
if rule.forbid and matches_scope(rule.forbid, action, qualified_resource_id, user.principal):
if rule.when:
if matches_conditions(parse_conditions(as_list(rule.when)), resource, user):
return False
elif rule.unless:
if not matches_conditions(parse_conditions(as_list(rule.unless)), resource, user):
return False
else:
return False
elif rule.permit and matches_scope(rule.permit, action, qualified_resource_id, user.principal):
if rule.when:
if matches_conditions(parse_conditions(as_list(rule.when)), resource, user):
return True
elif rule.unless:
if not matches_conditions(parse_conditions(as_list(rule.unless)), resource, user):
return True
else:
return True
# assume access is denied unless we find a rule that permits access
return False
class AccessDeniedError(RuntimeError):
def __init__(self, action: str | None = None, resource: ProtectedResource | None = None, user: User | None = None):
self.action = action
self.resource = resource
self.user = user
message = _build_access_denied_message(action, resource, user)
super().__init__(message)
def _build_access_denied_message(action: str | None, resource: ProtectedResource | None, user: User | None) -> str:
"""Build detailed error message for access denied scenarios."""
if action and resource and user:
resource_info = f"{resource.type}::{resource.identifier}"
user_info = f"'{user.principal}'"
if user.attributes:
attrs = ", ".join([f"{k}={v}" for k, v in user.attributes.items()])
user_info += f" (attributes: {attrs})"
message = f"User {user_info} cannot perform action '{action}' on resource '{resource_info}'"
else:
message = "Insufficient permissions"
return message

View file

@ -0,0 +1,129 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Protocol
class User(Protocol):
principal: str
attributes: dict[str, list[str]] | None
class ProtectedResource(Protocol):
type: str
identifier: str
owner: User
class Condition(Protocol):
def matches(self, resource: ProtectedResource, user: User) -> bool: ...
class UserInOwnersList:
def __init__(self, name: str):
self.name = name
def owners_values(self, resource: ProtectedResource) -> list[str] | None:
if (
hasattr(resource, "owner")
and resource.owner
and resource.owner.attributes
and self.name in resource.owner.attributes
):
return resource.owner.attributes[self.name]
else:
return None
def matches(self, resource: ProtectedResource, user: User) -> bool:
required = self.owners_values(resource)
if not required:
return True
if not user.attributes or self.name not in user.attributes or not user.attributes[self.name]:
return False
user_values = user.attributes[self.name]
for value in required:
if value in user_values:
return True
return False
def __repr__(self):
return f"user in owners {self.name}"
class UserNotInOwnersList(UserInOwnersList):
def __init__(self, name: str):
super().__init__(name)
def matches(self, resource: ProtectedResource, user: User) -> bool:
return not super().matches(resource, user)
def __repr__(self):
return f"user not in owners {self.name}"
class UserWithValueInList:
def __init__(self, name: str, value: str):
self.name = name
self.value = value
def matches(self, resource: ProtectedResource, user: User) -> bool:
if user.attributes and self.name in user.attributes:
return self.value in user.attributes[self.name]
print(f"User does not have {self.value} in {self.name}")
return False
def __repr__(self):
return f"user with {self.value} in {self.name}"
class UserWithValueNotInList(UserWithValueInList):
def __init__(self, name: str, value: str):
super().__init__(name, value)
def matches(self, resource: ProtectedResource, user: User) -> bool:
return not super().matches(resource, user)
def __repr__(self):
return f"user with {self.value} not in {self.name}"
class UserIsOwner:
def matches(self, resource: ProtectedResource, user: User) -> bool:
return resource.owner.principal == user.principal if resource.owner else False
def __repr__(self):
return "user is owner"
class UserIsNotOwner:
def matches(self, resource: ProtectedResource, user: User) -> bool:
return not resource.owner or resource.owner.principal != user.principal
def __repr__(self):
return "user is not owner"
def parse_condition(condition: str) -> Condition:
words = condition.split()
match words:
case ["user", "is", "owner"]:
return UserIsOwner()
case ["user", "is", "not", "owner"]:
return UserIsNotOwner()
case ["user", "with", value, "in", name]:
return UserWithValueInList(name, value)
case ["user", "with", value, "not", "in", name]:
return UserWithValueNotInList(name, value)
case ["user", "in", "owners", name]:
return UserInOwnersList(name)
case ["user", "not", "in", "owners", name]:
return UserNotInOwnersList(name)
case _:
raise ValueError(f"Invalid condition: {condition}")
def parse_conditions(conditions: list[str]) -> list[Condition]:
return [parse_condition(c) for c in conditions]

View file

@ -0,0 +1,107 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import StrEnum
from typing import Self
from pydantic import BaseModel, model_validator
from .conditions import parse_conditions
class Action(StrEnum):
CREATE = "create"
READ = "read"
UPDATE = "update"
DELETE = "delete"
class Scope(BaseModel):
principal: str | None = None
actions: Action | list[Action]
resource: str | None = None
def _mutually_exclusive(obj, a: str, b: str):
if getattr(obj, a) and getattr(obj, b):
raise ValueError(f"{a} and {b} are mutually exclusive")
def _require_one_of(obj, a: str, b: str):
if not getattr(obj, a) and not getattr(obj, b):
raise ValueError(f"on of {a} or {b} is required")
class AccessRule(BaseModel):
"""Access rule based loosely on cedar policy language
A rule defines a list of action either to permit or to forbid. It may specify a
principal or a resource that must match for the rule to take effect. The resource
to match should be specified in the form of a type qualified identifier, e.g.
model::my-model or vector_db::some-db, or a wildcard for all resources of a type,
e.g. model::*. If the principal or resource are not specified, they will match all
requests.
A rule may also specify a condition, either a 'when' or an 'unless', with additional
constraints as to where the rule applies. The constraints supported at present are:
- 'user with <attr-value> in <attr-name>'
- 'user with <attr-value> not in <attr-name>'
- 'user is owner'
- 'user is not owner'
- 'user in owners <attr-name>'
- 'user not in owners <attr-name>'
Rules are tested in order to find a match. If a match is found, the request is
permitted or forbidden depending on the type of rule. If no match is found, the
request is denied. If no rules are specified, a rule that allows any action as
long as the resource attributes match the user attributes is added
(i.e. the previous behaviour is the default).
Some examples in yaml:
- permit:
principal: user-1
actions: [create, read, delete]
resource: model::*
description: user-1 has full access to all models
- permit:
principal: user-2
actions: [read]
resource: model::model-1
description: user-2 has read access to model-1 only
- permit:
actions: [read]
when: user in owner teams
description: any user has read access to any resource created by a member of their team
- forbid:
actions: [create, read, delete]
resource: vector_db::*
unless: user with admin in roles
description: only user with admin role can use vector_db resources
"""
permit: Scope | None = None
forbid: Scope | None = None
when: str | list[str] | None = None
unless: str | list[str] | None = None
description: str | None = None
@model_validator(mode="after")
def validate_rule_format(self) -> Self:
_require_one_of(self, "permit", "forbid")
_mutually_exclusive(self, "permit", "forbid")
_mutually_exclusive(self, "when", "unless")
if isinstance(self.when, list):
parse_conditions(self.when)
elif self.when:
parse_conditions([self.when])
if isinstance(self.unless, list):
parse_conditions(self.unless)
elif self.unless:
parse_conditions([self.unless])
return self

177
llama_stack/core/build.py Normal file
View file

@ -0,0 +1,177 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import importlib.resources
import logging
import sys
from pathlib import Path
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.core.datatypes import BuildConfig
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.external import load_external_apis
from llama_stack.core.utils.exec import run_command
from llama_stack.core.utils.image_types import LlamaStackImageType
from llama_stack.providers.datatypes import Api
from llama_stack.templates.template import DistributionTemplate
log = logging.getLogger(__name__)
# These are the dependencies needed by the distribution server.
# `llama-stack` is automatically installed by the installation script.
SERVER_DEPENDENCIES = [
"aiosqlite",
"fastapi",
"fire",
"httpx",
"uvicorn",
"opentelemetry-sdk",
"opentelemetry-exporter-otlp-proto-http",
]
class ApiInput(BaseModel):
api: Api
provider: str
def get_provider_dependencies(
config: BuildConfig | DistributionTemplate,
) -> tuple[list[str], list[str], list[str]]:
"""Get normal and special dependencies from provider configuration."""
if isinstance(config, DistributionTemplate):
config = config.build_config()
providers = config.distribution_spec.providers
additional_pip_packages = config.additional_pip_packages
deps = []
external_provider_deps = []
registry = get_provider_registry(config)
for api_str, provider_or_providers in providers.items():
providers_for_api = registry[Api(api_str)]
providers = provider_or_providers if isinstance(provider_or_providers, list) else [provider_or_providers]
for provider in providers:
# Providers from BuildConfig and RunConfig are subtly different - not great
provider_type = provider if isinstance(provider, str) else provider.provider_type
if provider_type not in providers_for_api:
raise ValueError(f"Provider `{provider}` is not available for API `{api_str}`")
provider_spec = providers_for_api[provider_type]
if hasattr(provider_spec, "is_external") and provider_spec.is_external:
# this ensures we install the top level module for our external providers
if provider_spec.module:
if isinstance(provider_spec.module, str):
external_provider_deps.append(provider_spec.module)
else:
external_provider_deps.extend(provider_spec.module)
if hasattr(provider_spec, "pip_packages"):
deps.extend(provider_spec.pip_packages)
if hasattr(provider_spec, "container_image") and provider_spec.container_image:
raise ValueError("A stack's dependencies cannot have a container image")
normal_deps = []
special_deps = []
for package in deps:
if "--no-deps" in package or "--index-url" in package:
special_deps.append(package)
else:
normal_deps.append(package)
normal_deps.extend(additional_pip_packages or [])
return list(set(normal_deps)), list(set(special_deps)), list(set(external_provider_deps))
def print_pip_install_help(config: BuildConfig):
normal_deps, special_deps = get_provider_dependencies(config)
cprint(
f"Please install needed dependencies using the following commands:\n\nuv pip install {' '.join(normal_deps)}",
color="yellow",
file=sys.stderr,
)
for special_dep in special_deps:
cprint(f"uv pip install {special_dep}", color="yellow", file=sys.stderr)
print()
def build_image(
build_config: BuildConfig,
build_file_path: Path,
image_name: str,
template_or_config: str,
run_config: str | None = None,
):
container_base = build_config.distribution_spec.container_image or "python:3.12-slim"
normal_deps, special_deps, external_provider_deps = get_provider_dependencies(build_config)
normal_deps += SERVER_DEPENDENCIES
if build_config.external_apis_dir:
external_apis = load_external_apis(build_config)
if external_apis:
for _, api_spec in external_apis.items():
normal_deps.extend(api_spec.pip_packages)
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
script = str(importlib.resources.files("llama_stack") / "core/build_container.sh")
args = [
script,
"--template-or-config",
template_or_config,
"--image-name",
image_name,
"--container-base",
container_base,
"--normal-deps",
" ".join(normal_deps),
]
# When building from a config file (not a template), include the run config path in the
# build arguments
if run_config is not None:
args.extend(["--run-config", run_config])
elif build_config.image_type == LlamaStackImageType.CONDA.value:
script = str(importlib.resources.files("llama_stack") / "core/build_conda_env.sh")
args = [
script,
"--env-name",
str(image_name),
"--build-file-path",
str(build_file_path),
"--normal-deps",
" ".join(normal_deps),
]
elif build_config.image_type == LlamaStackImageType.VENV.value:
script = str(importlib.resources.files("llama_stack") / "core/build_venv.sh")
args = [
script,
"--env-name",
str(image_name),
"--normal-deps",
" ".join(normal_deps),
]
# Always pass both arguments, even if empty, to maintain consistent positional arguments
if special_deps:
args.extend(["--optional-deps", "#".join(special_deps)])
if external_provider_deps:
args.extend(
["--external-provider-deps", "#".join(external_provider_deps)]
) # the script will install external provider module, get its deps, and install those too.
return_code = run_command(args)
if return_code != 0:
log.error(
f"Failed to build target {image_name} with return code {return_code}",
)
return return_code

View file

@ -0,0 +1,207 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
PYPI_VERSION=${PYPI_VERSION:-}
# This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
set -euo pipefail
# Define color codes
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh"
# Usage function
usage() {
echo "Usage: $0 --env-name <conda_env_name> --build-file-path <build_file_path> --normal-deps <pip_dependencies> [--external-provider-deps <external_provider_deps>] [--optional-deps <special_pip_deps>]"
echo "Example: $0 --env-name my-conda-env --build-file-path ./my-stack-build.yaml --normal-deps 'numpy pandas scipy' --external-provider-deps 'foo' --optional-deps 'bar'"
exit 1
}
# Parse arguments
env_name=""
build_file_path=""
normal_deps=""
external_provider_deps=""
optional_deps=""
while [[ $# -gt 0 ]]; do
key="$1"
case "$key" in
--env-name)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --env-name requires a string value" >&2
usage
fi
env_name="$2"
shift 2
;;
--build-file-path)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --build-file-path requires a string value" >&2
usage
fi
build_file_path="$2"
shift 2
;;
--normal-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --normal-deps requires a string value" >&2
usage
fi
normal_deps="$2"
shift 2
;;
--external-provider-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --external-provider-deps requires a string value" >&2
usage
fi
external_provider_deps="$2"
shift 2
;;
--optional-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --optional-deps requires a string value" >&2
usage
fi
optional_deps="$2"
shift 2
;;
*)
echo "Unknown option: $1" >&2
usage
;;
esac
done
# Check required arguments
if [[ -z "$env_name" || -z "$build_file_path" || -z "$normal_deps" ]]; then
echo "Error: --env-name, --build-file-path, and --normal-deps are required." >&2
usage
fi
if [ -n "$LLAMA_STACK_DIR" ]; then
echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
fi
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR"
fi
ensure_conda_env_python310() {
# Use only global variables set by flag parser
local python_version="3.12"
if ! is_command_available conda; then
printf "${RED}Error: conda command not found. Is Conda installed and in your PATH?${NC}" >&2
exit 1
fi
if conda env list | grep -q "^${env_name} "; then
printf "Conda environment '${env_name}' exists. Checking Python version...\n"
current_version=$(conda run -n "${env_name}" python --version 2>&1 | cut -d' ' -f2 | cut -d'.' -f1,2)
if [ "$current_version" = "$python_version" ]; then
printf "Environment '${env_name}' already has Python ${python_version}. No action needed.\n"
else
printf "Updating environment '${env_name}' to Python ${python_version}...\n"
conda install -n "${env_name}" python="${python_version}" -y
fi
else
printf "Conda environment '${env_name}' does not exist. Creating with Python ${python_version}...\n"
conda create -n "${env_name}" python="${python_version}" -y
fi
eval "$(conda shell.bash hook)"
conda deactivate && conda activate "${env_name}"
"$CONDA_PREFIX"/bin/pip install uv
if [ -n "$TEST_PYPI_VERSION" ]; then
uv pip install fastapi libcst
uv pip install --extra-index-url https://test.pypi.org/simple/ \
llama-stack=="$TEST_PYPI_VERSION" \
"$normal_deps"
if [ -n "$optional_deps" ]; then
IFS='#' read -ra parts <<<"$optional_deps"
for part in "${parts[@]}"; do
echo "$part"
uv pip install $part
done
fi
if [ -n "$external_provider_deps" ]; then
IFS='#' read -ra parts <<<"$external_provider_deps"
for part in "${parts[@]}"; do
echo "$part"
uv pip install "$part"
done
fi
else
if [ -n "$LLAMA_STACK_DIR" ]; then
if [ ! -d "$LLAMA_STACK_DIR" ]; then
printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}\n" >&2
exit 1
fi
printf "Installing from LLAMA_STACK_DIR: $LLAMA_STACK_DIR\n"
uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR"
else
PYPI_VERSION="${PYPI_VERSION:-}"
if [ -n "$PYPI_VERSION" ]; then
SPEC_VERSION="llama-stack==${PYPI_VERSION}"
else
SPEC_VERSION="llama-stack"
fi
uv pip install --no-cache-dir "$SPEC_VERSION"
fi
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then
printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: $LLAMA_STACK_CLIENT_DIR${NC}\n" >&2
exit 1
fi
printf "Installing from LLAMA_STACK_CLIENT_DIR: $LLAMA_STACK_CLIENT_DIR\n"
uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR"
fi
printf "Installing pip dependencies\n"
uv pip install $normal_deps
if [ -n "$optional_deps" ]; then
IFS='#' read -ra parts <<<"$optional_deps"
for part in "${parts[@]}"; do
echo "$part"
uv pip install $part
done
fi
if [ -n "$external_provider_deps" ]; then
IFS='#' read -ra parts <<<"$external_provider_deps"
for part in "${parts[@]}"; do
echo "Getting provider spec for module: $part and installing dependencies"
package_name=$(echo "$part" | sed 's/[<>=!].*//')
python3 -c "
import importlib
import sys
try:
module = importlib.import_module(f'$package_name.provider')
spec = module.get_provider_spec()
if hasattr(spec, 'pip_packages') and spec.pip_packages:
print('\\n'.join(spec.pip_packages))
except Exception as e:
print(f'Error getting provider spec for $package_name: {e}', file=sys.stderr)
" | uv pip install -r -
done
fi
fi
mv "$build_file_path" "$CONDA_PREFIX"/llamastack-build.yaml
echo "Build spec configuration saved at $CONDA_PREFIX/llamastack-build.yaml"
}
ensure_conda_env_python310 "$env_name" "$build_file_path" "$normal_deps" "$optional_deps" "$external_provider_deps"

View file

@ -0,0 +1,411 @@
#!/usr/bin/env bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
PYPI_VERSION=${PYPI_VERSION:-}
BUILD_PLATFORM=${BUILD_PLATFORM:-}
# This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
# mounting is not supported by docker buildx, so we use COPY instead
USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-}
# Path to the run.yaml file in the container
RUN_CONFIG_PATH=/app/run.yaml
BUILD_CONTEXT_DIR=$(pwd)
set -euo pipefail
# Define color codes
RED='\033[0;31m'
NC='\033[0m' # No Color
# Usage function
usage() {
echo "Usage: $0 --image-name <image_name> --container-base <container_base> --normal-deps <pip_dependencies> [--run-config <run_config>] [--external-provider-deps <external_provider_deps>] [--optional-deps <special_pip_deps>]"
echo "Example: $0 --image-name llama-stack-img --container-base python:3.12-slim --normal-deps 'numpy pandas' --run-config ./run.yaml --external-provider-deps 'foo' --optional-deps 'bar'"
exit 1
}
# Parse arguments
image_name=""
container_base=""
normal_deps=""
external_provider_deps=""
optional_deps=""
run_config=""
template_or_config=""
while [[ $# -gt 0 ]]; do
key="$1"
case "$key" in
--image-name)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --image-name requires a string value" >&2
usage
fi
image_name="$2"
shift 2
;;
--container-base)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --container-base requires a string value" >&2
usage
fi
container_base="$2"
shift 2
;;
--normal-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --normal-deps requires a string value" >&2
usage
fi
normal_deps="$2"
shift 2
;;
--external-provider-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --external-provider-deps requires a string value" >&2
usage
fi
external_provider_deps="$2"
shift 2
;;
--optional-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --optional-deps requires a string value" >&2
usage
fi
optional_deps="$2"
shift 2
;;
--run-config)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --run-config requires a string value" >&2
usage
fi
run_config="$2"
shift 2
;;
--template-or-config)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --template-or-config requires a string value" >&2
usage
fi
template_or_config="$2"
shift 2
;;
*)
echo "Unknown option: $1" >&2
usage
;;
esac
done
# Check required arguments
if [[ -z "$image_name" || -z "$container_base" || -z "$normal_deps" ]]; then
echo "Error: --image-name, --container-base, and --normal-deps are required." >&2
usage
fi
CONTAINER_BINARY=${CONTAINER_BINARY:-docker}
CONTAINER_OPTS=${CONTAINER_OPTS:---progress=plain}
TEMP_DIR=$(mktemp -d)
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh"
add_to_container() {
output_file="$TEMP_DIR/Containerfile"
if [ -t 0 ]; then
printf '%s\n' "$1" >>"$output_file"
else
cat >>"$output_file"
fi
}
if ! is_command_available "$CONTAINER_BINARY"; then
printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2
exit 1
fi
if [[ $container_base == *"registry.access.redhat.com/ubi9"* ]]; then
add_to_container << EOF
FROM $container_base
WORKDIR /app
# We install the Python 3.12 dev headers and build tools so that any
# C-extension wheels (e.g. polyleven, faiss-cpu) can compile successfully.
RUN dnf -y update && dnf install -y iputils git net-tools wget \
vim-minimal python3.12 python3.12-pip python3.12-wheel \
python3.12-setuptools python3.12-devel gcc make && \
ln -s /bin/pip3.12 /bin/pip && ln -s /bin/python3.12 /bin/python && dnf clean all
ENV UV_SYSTEM_PYTHON=1
RUN pip install uv
EOF
else
add_to_container << EOF
FROM $container_base
WORKDIR /app
RUN apt-get update && apt-get install -y \
iputils-ping net-tools iproute2 dnsutils telnet \
curl wget telnet git\
procps psmisc lsof \
traceroute \
bubblewrap \
gcc \
&& rm -rf /var/lib/apt/lists/*
ENV UV_SYSTEM_PYTHON=1
RUN pip install uv
EOF
fi
# Add pip dependencies first since llama-stack is what will change most often
# so we can reuse layers.
if [ -n "$normal_deps" ]; then
read -ra pip_args <<< "$normal_deps"
quoted_deps=$(printf " %q" "${pip_args[@]}")
add_to_container << EOF
RUN uv pip install --no-cache $quoted_deps
EOF
fi
if [ -n "$optional_deps" ]; then
IFS='#' read -ra parts <<<"$optional_deps"
for part in "${parts[@]}"; do
read -ra pip_args <<< "$part"
quoted_deps=$(printf " %q" "${pip_args[@]}")
add_to_container <<EOF
RUN uv pip install --no-cache $quoted_deps
EOF
done
fi
if [ -n "$external_provider_deps" ]; then
IFS='#' read -ra parts <<<"$external_provider_deps"
for part in "${parts[@]}"; do
read -ra pip_args <<< "$part"
quoted_deps=$(printf " %q" "${pip_args[@]}")
add_to_container <<EOF
RUN uv pip install --no-cache $quoted_deps
EOF
add_to_container <<EOF
RUN python3 - <<PYTHON | uv pip install --no-cache -r -
import importlib
import sys
try:
package_name = '$part'.split('==')[0].split('>=')[0].split('<=')[0].split('!=')[0].split('<')[0].split('>')[0]
module = importlib.import_module(f'{package_name}.provider')
spec = module.get_provider_spec()
if hasattr(spec, 'pip_packages') and spec.pip_packages:
if isinstance(spec.pip_packages, (list, tuple)):
print('\n'.join(spec.pip_packages))
except Exception as e:
print(f'Error getting provider spec for {package_name}: {e}', file=sys.stderr)
PYTHON
EOF
done
fi
get_python_cmd() {
if is_command_available python; then
echo "python"
elif is_command_available python3; then
echo "python3"
else
echo "Error: Neither python nor python3 is installed. Please install Python to continue." >&2
exit 1
fi
}
if [ -n "$run_config" ]; then
# Copy the run config to the build context since it's an absolute path
cp "$run_config" "$BUILD_CONTEXT_DIR/run.yaml"
# Parse the run.yaml configuration to identify external provider directories
# If external providers are specified, copy their directory to the container
# and update the configuration to reference the new container path
python_cmd=$(get_python_cmd)
external_providers_dir=$($python_cmd -c "import yaml; config = yaml.safe_load(open('$run_config')); print(config.get('external_providers_dir') or '')")
external_providers_dir=$(eval echo "$external_providers_dir")
if [ -n "$external_providers_dir" ]; then
if [ -d "$external_providers_dir" ]; then
echo "Copying external providers directory: $external_providers_dir"
cp -r "$external_providers_dir" "$BUILD_CONTEXT_DIR/providers.d"
add_to_container << EOF
COPY providers.d /.llama/providers.d
EOF
fi
# Edit the run.yaml file to change the external_providers_dir to /.llama/providers.d
if [ "$(uname)" = "Darwin" ]; then
sed -i.bak -e 's|external_providers_dir:.*|external_providers_dir: /.llama/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
rm -f "$BUILD_CONTEXT_DIR/run.yaml.bak"
else
sed -i 's|external_providers_dir:.*|external_providers_dir: /.llama/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
fi
fi
# Copy run config into docker image
add_to_container << EOF
COPY run.yaml $RUN_CONFIG_PATH
EOF
fi
stack_mount="/app/llama-stack-source"
client_mount="/app/llama-stack-client-source"
install_local_package() {
local dir="$1"
local mount_point="$2"
local name="$3"
if [ ! -d "$dir" ]; then
echo "${RED}Warning: $name is set but directory does not exist: $dir${NC}" >&2
exit 1
fi
if [ "$USE_COPY_NOT_MOUNT" = "true" ]; then
add_to_container << EOF
COPY $dir $mount_point
EOF
fi
add_to_container << EOF
RUN uv pip install --no-cache -e $mount_point
EOF
}
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
install_local_package "$LLAMA_STACK_CLIENT_DIR" "$client_mount" "LLAMA_STACK_CLIENT_DIR"
fi
if [ -n "$LLAMA_STACK_DIR" ]; then
install_local_package "$LLAMA_STACK_DIR" "$stack_mount" "LLAMA_STACK_DIR"
else
if [ -n "$TEST_PYPI_VERSION" ]; then
# these packages are damaged in test-pypi, so install them first
add_to_container << EOF
RUN uv pip install --no-cache fastapi libcst
EOF
add_to_container << EOF
RUN uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \
--index-strategy unsafe-best-match \
llama-stack==$TEST_PYPI_VERSION
EOF
else
if [ -n "$PYPI_VERSION" ]; then
SPEC_VERSION="llama-stack==${PYPI_VERSION}"
else
SPEC_VERSION="llama-stack"
fi
add_to_container << EOF
RUN uv pip install --no-cache $SPEC_VERSION
EOF
fi
fi
# remove uv after installation
add_to_container << EOF
RUN pip uninstall -y uv
EOF
# If a run config is provided, we use the --config flag
if [[ -n "$run_config" ]]; then
add_to_container << EOF
ENTRYPOINT ["python", "-m", "llama_stack.core.server.server", "--config", "$RUN_CONFIG_PATH"]
EOF
# If a template is provided (not a yaml file), we use the --template flag
elif [[ "$template_or_config" != *.yaml ]]; then
add_to_container << EOF
ENTRYPOINT ["python", "-m", "llama_stack.core.server.server", "--template", "$template_or_config"]
EOF
fi
# Add other require item commands genearic to all containers
add_to_container << EOF
RUN mkdir -p /.llama /.cache && chmod -R g+rw /app /.llama /.cache
EOF
printf "Containerfile created successfully in %s/Containerfile\n\n" "$TEMP_DIR"
cat "$TEMP_DIR"/Containerfile
printf "\n"
# Start building the CLI arguments
CLI_ARGS=()
# Read CONTAINER_OPTS and put it in an array
read -ra CLI_ARGS <<< "$CONTAINER_OPTS"
if [ "$USE_COPY_NOT_MOUNT" != "true" ]; then
if [ -n "$LLAMA_STACK_DIR" ]; then
CLI_ARGS+=("-v" "$(readlink -f "$LLAMA_STACK_DIR"):$stack_mount")
fi
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
CLI_ARGS+=("-v" "$(readlink -f "$LLAMA_STACK_CLIENT_DIR"):$client_mount")
fi
fi
if is_command_available selinuxenabled && selinuxenabled; then
# Disable SELinux labels -- we don't want to relabel the llama-stack source dir
CLI_ARGS+=("--security-opt" "label=disable")
fi
# Set version tag based on PyPI version
if [ -n "$PYPI_VERSION" ]; then
version_tag="$PYPI_VERSION"
elif [ -n "$TEST_PYPI_VERSION" ]; then
version_tag="test-$TEST_PYPI_VERSION"
elif [[ -n "$LLAMA_STACK_DIR" || -n "$LLAMA_STACK_CLIENT_DIR" ]]; then
version_tag="dev"
else
URL="https://pypi.org/pypi/llama-stack/json"
version_tag=$(curl -s $URL | jq -r '.info.version')
fi
# Add version tag to image name
image_tag="$image_name:$version_tag"
# Detect platform architecture
ARCH=$(uname -m)
if [ -n "$BUILD_PLATFORM" ]; then
CLI_ARGS+=("--platform" "$BUILD_PLATFORM")
elif [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; then
CLI_ARGS+=("--platform" "linux/arm64")
elif [ "$ARCH" = "x86_64" ]; then
CLI_ARGS+=("--platform" "linux/amd64")
else
echo "Unsupported architecture: $ARCH"
exit 1
fi
echo "PWD: $(pwd)"
echo "Containerfile: $TEMP_DIR/Containerfile"
set -x
$CONTAINER_BINARY build \
"${CLI_ARGS[@]}" \
-t "$image_tag" \
-f "$TEMP_DIR/Containerfile" \
"$BUILD_CONTEXT_DIR"
# clean up tmp/configs
rm -rf "$BUILD_CONTEXT_DIR/run.yaml" "$TEMP_DIR"
set +x
echo "Success!"

207
llama_stack/core/build_venv.sh Executable file
View file

@ -0,0 +1,207 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# TODO: combine this with build_conda_env.sh since it is almost identical
# the only difference is that we don't do any conda-specific setup
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
# This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
UV_SYSTEM_PYTHON=${UV_SYSTEM_PYTHON:-}
VIRTUAL_ENV=${VIRTUAL_ENV:-}
set -euo pipefail
# Define color codes
RED='\033[0;31m'
NC='\033[0m' # No Color
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh"
# Usage function
usage() {
echo "Usage: $0 --env-name <env_name> --normal-deps <pip_dependencies> [--external-provider-deps <external_provider_deps>] [--optional-deps <special_pip_deps>]"
echo "Example: $0 --env-name mybuild --normal-deps 'numpy pandas scipy' --external-provider-deps 'foo' --optional-deps 'bar'"
exit 1
}
# Parse arguments
env_name=""
normal_deps=""
external_provider_deps=""
optional_deps=""
while [[ $# -gt 0 ]]; do
key="$1"
case "$key" in
--env-name)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --env-name requires a string value" >&2
usage
fi
env_name="$2"
shift 2
;;
--normal-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --normal-deps requires a string value" >&2
usage
fi
normal_deps="$2"
shift 2
;;
--external-provider-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --external-provider-deps requires a string value" >&2
usage
fi
external_provider_deps="$2"
shift 2
;;
--optional-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --optional-deps requires a string value" >&2
usage
fi
optional_deps="$2"
shift 2
;;
*)
echo "Unknown option: $1" >&2
usage
;;
esac
done
# Check required arguments
if [[ -z "$env_name" || -z "$normal_deps" ]]; then
echo "Error: --env-name and --normal-deps are required." >&2
usage
fi
if [ -n "$LLAMA_STACK_DIR" ]; then
echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
fi
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR"
fi
# pre-run checks to make sure we can proceed with the installation
pre_run_checks() {
local env_name="$1"
if ! is_command_available uv; then
echo "uv is not installed, trying to install it."
if ! is_command_available pip; then
echo "pip is not installed, cannot automatically install 'uv'."
echo "Follow this link to install it:"
echo "https://docs.astral.sh/uv/getting-started/installation/"
exit 1
else
pip install uv
fi
fi
# checking if an environment with the same name already exists
if [ -d "$env_name" ]; then
echo "Environment '$env_name' already exists, re-using it."
fi
}
run() {
# Use only global variables set by flag parser
if [ -n "$UV_SYSTEM_PYTHON" ] || [ "$env_name" == "__system__" ]; then
echo "Installing dependencies in system Python environment"
export UV_SYSTEM_PYTHON=1
elif [ "$VIRTUAL_ENV" == "$env_name" ]; then
echo "Virtual environment $env_name is already active"
else
echo "Using virtual environment $env_name"
uv venv "$env_name"
source "$env_name/bin/activate"
fi
if [ -n "$TEST_PYPI_VERSION" ]; then
uv pip install fastapi libcst
uv pip install --extra-index-url https://test.pypi.org/simple/ \
--index-strategy unsafe-best-match \
llama-stack=="$TEST_PYPI_VERSION" \
$normal_deps
if [ -n "$optional_deps" ]; then
IFS='#' read -ra parts <<<"$optional_deps"
for part in "${parts[@]}"; do
echo "$part"
uv pip install $part
done
fi
if [ -n "$external_provider_deps" ]; then
IFS='#' read -ra parts <<<"$external_provider_deps"
for part in "${parts[@]}"; do
echo "$part"
uv pip install "$part"
done
fi
else
if [ -n "$LLAMA_STACK_DIR" ]; then
if [ ! -d "$LLAMA_STACK_DIR" ]; then
printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_DIR" >&2
exit 1
fi
printf "Installing from LLAMA_STACK_DIR: %s\n" "$LLAMA_STACK_DIR"
uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR"
else
uv pip install --no-cache-dir llama-stack
fi
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then
printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_CLIENT_DIR" >&2
exit 1
fi
printf "Installing from LLAMA_STACK_CLIENT_DIR: %s\n" "$LLAMA_STACK_CLIENT_DIR"
uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR"
fi
printf "Installing pip dependencies\n"
uv pip install $normal_deps
if [ -n "$optional_deps" ]; then
IFS='#' read -ra parts <<<"$optional_deps"
for part in "${parts[@]}"; do
echo "Installing special provider module: $part"
uv pip install $part
done
fi
if [ -n "$external_provider_deps" ]; then
IFS='#' read -ra parts <<<"$external_provider_deps"
for part in "${parts[@]}"; do
echo "Installing external provider module: $part"
uv pip install "$part"
echo "Getting provider spec for module: $part and installing dependencies"
package_name=$(echo "$part" | sed 's/[<>=!].*//')
python3 -c "
import importlib
import sys
try:
module = importlib.import_module(f'$package_name.provider')
spec = module.get_provider_spec()
if hasattr(spec, 'pip_packages') and spec.pip_packages:
print('\\n'.join(spec.pip_packages))
except Exception as e:
print(f'Error getting provider spec for $package_name: {e}', file=sys.stderr)
" | uv pip install -r -
done
fi
fi
}
pre_run_checks "$env_name"
run

189
llama_stack/core/client.py Normal file
View file

@ -0,0 +1,189 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import inspect
import json
import sys
from collections.abc import AsyncIterator
from enum import Enum
from typing import Any, Union, get_args, get_origin
import httpx
from pydantic import BaseModel, parse_obj_as
from termcolor import cprint
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
from llama_stack.providers.datatypes import RemoteProviderConfig
_CLIENT_CLASSES = {}
async def get_client_impl(protocol, config: RemoteProviderConfig, _deps: Any):
client_class = create_api_client_class(protocol)
impl = client_class(config.url)
await impl.initialize()
return impl
def create_api_client_class(protocol) -> type:
if protocol in _CLIENT_CLASSES:
return _CLIENT_CLASSES[protocol]
class APIClient:
def __init__(self, base_url: str):
print(f"({protocol.__name__}) Connecting to {base_url}")
self.base_url = base_url.rstrip("/")
self.routes = {}
# Store routes for this protocol
for name, method in inspect.getmembers(protocol):
if hasattr(method, "__webmethod__"):
sig = inspect.signature(method)
self.routes[name] = (method.__webmethod__, sig)
async def initialize(self):
pass
async def shutdown(self):
pass
async def __acall__(self, method_name: str, *args, **kwargs) -> Any:
assert method_name in self.routes, f"Unknown endpoint: {method_name}"
# TODO: make this more precise, same thing needs to happen in server.py
is_streaming = kwargs.get("stream", False)
if is_streaming:
return self._call_streaming(method_name, *args, **kwargs)
else:
return await self._call_non_streaming(method_name, *args, **kwargs)
async def _call_non_streaming(self, method_name: str, *args, **kwargs) -> Any:
_, sig = self.routes[method_name]
if sig.return_annotation is None:
return_type = None
else:
return_type = extract_non_async_iterator_type(sig.return_annotation)
assert return_type, f"Could not extract return type for {sig.return_annotation}"
async with httpx.AsyncClient() as client:
params = self.httpx_request_params(method_name, *args, **kwargs)
response = await client.request(**params)
response.raise_for_status()
j = response.json()
if j is None:
return None
# print(f"({protocol.__name__}) Returning {j}, type {return_type}")
return parse_obj_as(return_type, j)
async def _call_streaming(self, method_name: str, *args, **kwargs) -> Any:
webmethod, sig = self.routes[method_name]
return_type = extract_async_iterator_type(sig.return_annotation)
assert return_type, f"Could not extract return type for {sig.return_annotation}"
async with httpx.AsyncClient() as client:
params = self.httpx_request_params(method_name, *args, **kwargs)
async with client.stream(**params) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if line.startswith("data:"):
data = line[len("data: ") :]
try:
data = json.loads(data)
if "error" in data:
cprint(data, color="red", file=sys.stderr)
continue
yield parse_obj_as(return_type, data)
except Exception as e:
cprint(f"Error with parsing or validation: {e}", color="red", file=sys.stderr)
cprint(data, color="red", file=sys.stderr)
def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict:
webmethod, sig = self.routes[method_name]
parameters = list(sig.parameters.values())[1:] # skip `self`
for i, param in enumerate(parameters):
if i >= len(args):
break
kwargs[param.name] = args[i]
url = f"{self.base_url}/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
def convert(value):
if isinstance(value, list):
return [convert(v) for v in value]
elif isinstance(value, dict):
return {k: convert(v) for k, v in value.items()}
elif isinstance(value, BaseModel):
return json.loads(value.model_dump_json())
elif isinstance(value, Enum):
return value.value
else:
return value
params = {}
data = {}
if webmethod.method == "GET":
params.update(kwargs)
else:
data.update(convert(kwargs))
ret = dict(
method=webmethod.method or "POST",
url=url,
headers={
"Accept": "application/json",
"Content-Type": "application/json",
},
timeout=30,
)
if params:
ret["params"] = params
if data:
ret["json"] = data
return ret
# Add protocol methods to the wrapper
for name, method in inspect.getmembers(protocol):
if hasattr(method, "__webmethod__"):
async def method_impl(self, *args, method_name=name, **kwargs):
return await self.__acall__(method_name, *args, **kwargs)
method_impl.__name__ = name
method_impl.__qualname__ = f"APIClient.{name}"
method_impl.__signature__ = inspect.signature(method)
setattr(APIClient, name, method_impl)
# Name the class after the protocol
APIClient.__name__ = f"{protocol.__name__}Client"
_CLIENT_CLASSES[protocol] = APIClient
return APIClient
# not quite general these methods are
def extract_non_async_iterator_type(type_hint):
if get_origin(type_hint) is Union:
args = get_args(type_hint)
for arg in args:
if not issubclass(get_origin(arg) or arg, AsyncIterator):
return arg
return type_hint
def extract_async_iterator_type(type_hint):
if get_origin(type_hint) is Union:
args = get_args(type_hint)
for arg in args:
if issubclass(get_origin(arg) or arg, AsyncIterator):
inner_args = get_args(arg)
return inner_args[0]
return None

51
llama_stack/core/common.sh Executable file
View file

@ -0,0 +1,51 @@
#!/usr/bin/env bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
cleanup() {
envname="$1"
set +x
echo "Cleaning up..."
conda deactivate
conda env remove --name "$envname" -y
}
handle_int() {
if [ -n "$ENVNAME" ]; then
cleanup "$ENVNAME"
fi
exit 1
}
handle_exit() {
if [ $? -ne 0 ]; then
echo -e "\033[1;31mABORTING.\033[0m"
if [ -n "$ENVNAME" ]; then
cleanup "$ENVNAME"
fi
fi
}
setup_cleanup_handlers() {
trap handle_int INT
trap handle_exit EXIT
if is_command_available conda; then
__conda_setup="$('conda' 'shell.bash' 'hook' 2>/dev/null)"
eval "$__conda_setup"
conda deactivate
else
echo "conda is not available"
exit 1
fi
}
# check if a command is present
is_command_available() {
command -v "$1" &>/dev/null
}

View file

@ -0,0 +1,181 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import textwrap
from typing import Any
from llama_stack.core.datatypes import (
LLAMA_STACK_RUN_CONFIG_VERSION,
DistributionSpec,
Provider,
StackRunConfig,
)
from llama_stack.core.distribution import (
builtin_automatically_routed_apis,
get_provider_registry,
)
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.core.utils.prompt_for_config import prompt_for_config
from llama_stack.providers.datatypes import Api, ProviderSpec
logger = logging.getLogger(__name__)
def configure_single_provider(registry: dict[str, ProviderSpec], provider: Provider) -> Provider:
provider_spec = registry[provider.provider_type]
config_type = instantiate_class_type(provider_spec.config_class)
try:
if provider.config:
existing = config_type(**provider.config)
else:
existing = None
except Exception:
existing = None
cfg = prompt_for_config(config_type, existing)
return Provider(
provider_id=provider.provider_id,
provider_type=provider.provider_type,
config=cfg.model_dump(),
)
def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec) -> StackRunConfig:
is_nux = len(config.providers) == 0
if is_nux:
logger.info(
textwrap.dedent(
"""
Llama Stack is composed of several APIs working together. For each API served by the Stack,
we need to configure the providers (implementations) you want to use for these APIs.
"""
)
)
provider_registry = get_provider_registry()
builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()]
if config.apis:
apis_to_serve = config.apis
else:
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect, Api.providers)]
for api_str in apis_to_serve:
api = Api(api_str)
if api in builtin_apis:
continue
if api not in provider_registry:
raise ValueError(f"Unknown API `{api_str}`")
existing_providers = config.providers.get(api_str, [])
if existing_providers:
logger.info(f"Re-configuring existing providers for API `{api_str}`...")
updated_providers = []
for p in existing_providers:
logger.info(f"> Configuring provider `({p.provider_type})`")
updated_providers.append(configure_single_provider(provider_registry[api], p))
logger.info("")
else:
# we are newly configuring this API
plist = build_spec.providers.get(api_str, [])
plist = plist if isinstance(plist, list) else [plist]
if not plist:
raise ValueError(f"No provider configured for API {api_str}?")
logger.info(f"Configuring API `{api_str}`...")
updated_providers = []
for i, provider in enumerate(plist):
if i >= 1:
others = ", ".join(p.provider_type for p in plist[i:])
logger.info(
f"Not configuring other providers ({others}) interactively. Please edit the resulting YAML directly.\n"
)
break
logger.info(f"> Configuring provider `({provider.provider_type})`")
pid = provider.provider_type.split("::")[-1]
updated_providers.append(
configure_single_provider(
provider_registry[api],
Provider(
provider_id=(f"{pid}-{i:02d}" if len(plist) > 1 else pid),
provider_type=provider.provider_type,
config={},
),
)
)
logger.info("")
config.providers[api_str] = updated_providers
return config
def upgrade_from_routing_table(
config_dict: dict[str, Any],
) -> dict[str, Any]:
def get_providers(entries):
return [
Provider(
provider_id=(f"{entry['provider_type']}-{i:02d}" if len(entries) > 1 else entry["provider_type"]),
provider_type=entry["provider_type"],
config=entry["config"],
)
for i, entry in enumerate(entries)
]
providers_by_api = {}
routing_table = config_dict.get("routing_table", {})
for api_str, entries in routing_table.items():
providers = get_providers(entries)
providers_by_api[api_str] = providers
provider_map = config_dict.get("api_providers", config_dict.get("provider_map", {}))
if provider_map:
for api_str, provider in provider_map.items():
if isinstance(provider, dict) and "provider_type" in provider:
providers_by_api[api_str] = [
Provider(
provider_id=f"{provider['provider_type']}",
provider_type=provider["provider_type"],
config=provider["config"],
)
]
config_dict["providers"] = providers_by_api
config_dict.pop("routing_table", None)
config_dict.pop("api_providers", None)
config_dict.pop("provider_map", None)
config_dict["apis"] = config_dict["apis_to_serve"]
config_dict.pop("apis_to_serve", None)
return config_dict
def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfig:
version = config_dict.get("version", None)
if version == LLAMA_STACK_RUN_CONFIG_VERSION:
processed_config_dict = replace_env_vars(config_dict)
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))
if "routing_table" in config_dict:
logger.info("Upgrading config...")
config_dict = upgrade_from_routing_table(config_dict)
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
if not config_dict.get("external_providers_dir", None):
config_dict["external_providers_dir"] = EXTERNAL_PROVIDERS_DIR
processed_config_dict = replace_env_vars(config_dict)
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))

View file

@ -0,0 +1,463 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import StrEnum
from pathlib import Path
from typing import Annotated, Any, Literal, Self
from pydantic import BaseModel, Field, field_validator, model_validator
from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Dataset, DatasetInput
from llama_stack.apis.eval import Eval
from llama_stack.apis.inference import Inference
from llama_stack.apis.models import Model, ModelInput
from llama_stack.apis.resource import Resource
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
from llama_stack.apis.shields import Shield, ShieldInput
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.access_control.datatypes import AccessRule
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig
LLAMA_STACK_BUILD_CONFIG_VERSION = 2
LLAMA_STACK_RUN_CONFIG_VERSION = 2
RoutingKey = str | list[str]
class RegistryEntrySource(StrEnum):
via_register_api = "via_register_api"
listed_from_provider = "listed_from_provider"
class User(BaseModel):
principal: str
# further attributes that may be used for access control decisions
attributes: dict[str, list[str]] | None = None
def __init__(self, principal: str, attributes: dict[str, list[str]] | None):
super().__init__(principal=principal, attributes=attributes)
class ResourceWithOwner(Resource):
"""Extension of Resource that adds an optional owner, i.e. the user that created the
resource. This can be used to constrain access to the resource."""
owner: User | None = None
source: RegistryEntrySource = RegistryEntrySource.via_register_api
# Use the extended Resource for all routable objects
class ModelWithOwner(Model, ResourceWithOwner):
pass
class ShieldWithOwner(Shield, ResourceWithOwner):
pass
class VectorDBWithOwner(VectorDB, ResourceWithOwner):
pass
class DatasetWithOwner(Dataset, ResourceWithOwner):
pass
class ScoringFnWithOwner(ScoringFn, ResourceWithOwner):
pass
class BenchmarkWithOwner(Benchmark, ResourceWithOwner):
pass
class ToolWithOwner(Tool, ResourceWithOwner):
pass
class ToolGroupWithOwner(ToolGroup, ResourceWithOwner):
pass
RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | Tool | ToolGroup
RoutableObjectWithProvider = Annotated[
ModelWithOwner
| ShieldWithOwner
| VectorDBWithOwner
| DatasetWithOwner
| ScoringFnWithOwner
| BenchmarkWithOwner
| ToolWithOwner
| ToolGroupWithOwner,
Field(discriminator="type"),
]
RoutedProtocol = Inference | Safety | VectorIO | DatasetIO | Scoring | Eval | ToolRuntime
# Example: /inference, /safety
class AutoRoutedProviderSpec(ProviderSpec):
provider_type: str = "router"
config_class: str = ""
container_image: str | None = None
routing_table_api: Api
module: str
provider_data_validator: str | None = Field(
default=None,
)
@property
def pip_packages(self) -> list[str]:
raise AssertionError("Should not be called on AutoRoutedProviderSpec")
# Example: /models, /shields
class RoutingTableProviderSpec(ProviderSpec):
provider_type: str = "routing_table"
config_class: str = ""
container_image: str | None = None
router_api: Api
module: str
pip_packages: list[str] = Field(default_factory=list)
class Provider(BaseModel):
# provider_id of None means that the provider is not enabled - this happens
# when the provider is enabled via a conditional environment variable
provider_id: str | None
provider_type: str
config: dict[str, Any] = {}
module: str | None = Field(
default=None,
description="""
Fully-qualified name of the external provider module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the adapter implementation
Example: `module: ramalama_stack`
""",
)
class BuildProvider(BaseModel):
provider_type: str
module: str | None = Field(
default=None,
description="""
Fully-qualified name of the external provider module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the adapter implementation
Example: `module: ramalama_stack`
""",
)
class DistributionSpec(BaseModel):
description: str | None = Field(
default="",
description="Description of the distribution",
)
container_image: str | None = None
providers: dict[str, list[BuildProvider]] = Field(
default_factory=dict,
description="""
Provider Types for each of the APIs provided by this distribution. If you
select multiple providers, you should provide an appropriate 'routing_map'
in the runtime configuration to help route to the correct provider.
""",
)
class LoggingConfig(BaseModel):
category_levels: dict[str, str] = Field(
default_factory=dict,
description="""
Dictionary of different logging configurations for different portions (ex: core, server) of llama stack""",
)
class OAuth2JWKSConfig(BaseModel):
# The JWKS URI for collecting public keys
uri: str
token: str | None = Field(default=None, description="token to authorise access to jwks")
key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates")
class OAuth2IntrospectionConfig(BaseModel):
url: str
client_id: str
client_secret: str
send_secret_in_body: bool = False
class AuthProviderType(StrEnum):
"""Supported authentication provider types."""
OAUTH2_TOKEN = "oauth2_token"
GITHUB_TOKEN = "github_token"
CUSTOM = "custom"
class OAuth2TokenAuthConfig(BaseModel):
"""Configuration for OAuth2 token authentication."""
type: Literal[AuthProviderType.OAUTH2_TOKEN] = AuthProviderType.OAUTH2_TOKEN
audience: str = Field(default="llama-stack")
verify_tls: bool = Field(default=True)
tls_cafile: Path | None = Field(default=None)
issuer: str | None = Field(default=None, description="The OIDC issuer URL.")
claims_mapping: dict[str, str] = Field(
default_factory=lambda: {
"sub": "roles",
"username": "roles",
"groups": "teams",
"team": "teams",
"project": "projects",
"tenant": "namespaces",
"namespace": "namespaces",
},
)
jwks: OAuth2JWKSConfig | None = Field(default=None, description="JWKS configuration")
introspection: OAuth2IntrospectionConfig | None = Field(
default=None, description="OAuth2 introspection configuration"
)
@classmethod
@field_validator("claims_mapping")
def validate_claims_mapping(cls, v):
for key, value in v.items():
if not value:
raise ValueError(f"claims_mapping value cannot be empty: {key}")
return v
@model_validator(mode="after")
def validate_mode(self) -> Self:
if not self.jwks and not self.introspection:
raise ValueError("One of jwks or introspection must be configured")
if self.jwks and self.introspection:
raise ValueError("At present only one of jwks or introspection should be configured")
return self
class CustomAuthConfig(BaseModel):
"""Configuration for custom authentication."""
type: Literal[AuthProviderType.CUSTOM] = AuthProviderType.CUSTOM
endpoint: str = Field(
...,
description="Custom authentication endpoint URL",
)
class GitHubTokenAuthConfig(BaseModel):
"""Configuration for GitHub token authentication."""
type: Literal[AuthProviderType.GITHUB_TOKEN] = AuthProviderType.GITHUB_TOKEN
github_api_base_url: str = Field(
default="https://api.github.com",
description="Base URL for GitHub API (use https://api.github.com for public GitHub)",
)
claims_mapping: dict[str, str] = Field(
default_factory=lambda: {
"login": "roles",
"organizations": "teams",
},
description="Mapping from GitHub user fields to access attributes",
)
AuthProviderConfig = Annotated[
OAuth2TokenAuthConfig | GitHubTokenAuthConfig | CustomAuthConfig,
Field(discriminator="type"),
]
class AuthenticationConfig(BaseModel):
"""Top-level authentication configuration."""
provider_config: AuthProviderConfig = Field(
...,
description="Authentication provider configuration",
)
access_policy: list[AccessRule] = Field(
default=[],
description="Rules for determining access to resources",
)
class AuthenticationRequiredError(Exception):
pass
class QuotaPeriod(StrEnum):
DAY = "day"
class QuotaConfig(BaseModel):
kvstore: SqliteKVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)")
anonymous_max_requests: int = Field(default=100, description="Max requests for unauthenticated clients per period")
authenticated_max_requests: int = Field(
default=1000, description="Max requests for authenticated clients per period"
)
period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
class ServerConfig(BaseModel):
port: int = Field(
default=8321,
description="Port to listen on",
ge=1024,
le=65535,
)
tls_certfile: str | None = Field(
default=None,
description="Path to TLS certificate file for HTTPS",
)
tls_keyfile: str | None = Field(
default=None,
description="Path to TLS key file for HTTPS",
)
tls_cafile: str | None = Field(
default=None,
description="Path to TLS CA file for HTTPS with mutual TLS authentication",
)
auth: AuthenticationConfig | None = Field(
default=None,
description="Authentication configuration for the server",
)
host: str | None = Field(
default=None,
description="The host the server should listen on",
)
quota: QuotaConfig | None = Field(
default=None,
description="Per client quota request configuration",
)
class StackRunConfig(BaseModel):
version: int = LLAMA_STACK_RUN_CONFIG_VERSION
image_name: str = Field(
...,
description="""
Reference to the distribution this package refers to. For unregistered (adhoc) packages,
this could be just a hash
""",
)
container_image: str | None = Field(
default=None,
description="Reference to the container image if this package refers to a container",
)
apis: list[str] = Field(
default_factory=list,
description="""
The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
)
providers: dict[str, list[Provider]] = Field(
description="""
One or more providers to use for each API. The same provider_type (e.g., meta-reference)
can be instantiated multiple times (with different configs) if necessary.
""",
)
metadata_store: KVStoreConfig | None = Field(
default=None,
description="""
Configuration for the persistence store used by the distribution registry. If not specified,
a default SQLite store will be used.""",
)
inference_store: SqlStoreConfig | None = Field(
default=None,
description="""
Configuration for the persistence store used by the inference API. If not specified,
a default SQLite store will be used.""",
)
# registry of "resources" in the distribution
models: list[ModelInput] = Field(default_factory=list)
shields: list[ShieldInput] = Field(default_factory=list)
vector_dbs: list[VectorDBInput] = Field(default_factory=list)
datasets: list[DatasetInput] = Field(default_factory=list)
scoring_fns: list[ScoringFnInput] = Field(default_factory=list)
benchmarks: list[BenchmarkInput] = Field(default_factory=list)
tool_groups: list[ToolGroupInput] = Field(default_factory=list)
logging: LoggingConfig | None = Field(default=None, description="Configuration for Llama Stack Logging")
server: ServerConfig = Field(
default_factory=ServerConfig,
description="Configuration for the HTTP(S) server",
)
external_providers_dir: Path | None = Field(
default=None,
description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
)
external_apis_dir: Path | None = Field(
default=None,
description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.",
)
@field_validator("external_providers_dir")
@classmethod
def validate_external_providers_dir(cls, v):
if v is None:
return None
if isinstance(v, str):
return Path(v)
return v
class BuildConfig(BaseModel):
version: int = LLAMA_STACK_BUILD_CONFIG_VERSION
distribution_spec: DistributionSpec = Field(description="The distribution spec to build including API providers. ")
image_type: str = Field(
default="conda",
description="Type of package to build (conda | container | venv)",
)
image_name: str | None = Field(
default=None,
description="Name of the distribution to build",
)
external_providers_dir: Path | None = Field(
default=None,
description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. "
"pip_packages MUST contain the provider package name.",
)
additional_pip_packages: list[str] = Field(
default_factory=list,
description="Additional pip packages to install in the distribution. These packages will be installed in the distribution environment.",
)
external_apis_dir: Path | None = Field(
default=None,
description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.",
)
@field_validator("external_providers_dir")
@classmethod
def validate_external_providers_dir(cls, v):
if v is None:
return None
if isinstance(v, str):
return Path(v)
return v

View file

@ -0,0 +1,277 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import glob
import importlib
import os
from typing import Any
import yaml
from pydantic import BaseModel
from llama_stack.core.datatypes import BuildConfig, DistributionSpec
from llama_stack.core.external import load_external_apis
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
)
logger = get_logger(name=__name__, category="core")
def stack_apis() -> list[Api]:
return list(Api)
class AutoRoutedApiInfo(BaseModel):
routing_table_api: Api
router_api: Api
def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]:
return [
AutoRoutedApiInfo(
routing_table_api=Api.models,
router_api=Api.inference,
),
AutoRoutedApiInfo(
routing_table_api=Api.shields,
router_api=Api.safety,
),
AutoRoutedApiInfo(
routing_table_api=Api.vector_dbs,
router_api=Api.vector_io,
),
AutoRoutedApiInfo(
routing_table_api=Api.datasets,
router_api=Api.datasetio,
),
AutoRoutedApiInfo(
routing_table_api=Api.scoring_functions,
router_api=Api.scoring,
),
AutoRoutedApiInfo(
routing_table_api=Api.benchmarks,
router_api=Api.eval,
),
AutoRoutedApiInfo(
routing_table_api=Api.tool_groups,
router_api=Api.tool_runtime,
),
]
def providable_apis() -> list[Api]:
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers]
def _load_remote_provider_spec(spec_data: dict[str, Any], api: Api) -> ProviderSpec:
adapter = AdapterSpec(**spec_data["adapter"])
spec = remote_provider_spec(
api=api,
adapter=adapter,
api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
)
return spec
def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
spec = InlineProviderSpec(
api=api,
provider_type=f"inline::{provider_name}",
pip_packages=spec_data.get("pip_packages", []),
module=spec_data["module"],
config_class=spec_data["config_class"],
api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
optional_api_dependencies=[Api(dep) for dep in spec_data.get("optional_api_dependencies", [])],
provider_data_validator=spec_data.get("provider_data_validator"),
container_image=spec_data.get("container_image"),
)
return spec
def get_provider_registry(config=None) -> dict[Api, dict[str, ProviderSpec]]:
"""Get the provider registry, optionally including external providers.
This function loads both built-in providers and external providers from YAML files or from their provided modules.
External providers are loaded from a directory structure like:
providers.d/
remote/
inference/
custom_ollama.yaml
vllm.yaml
vector_io/
qdrant.yaml
safety/
llama-guard.yaml
inline/
inference/
custom_ollama.yaml
vllm.yaml
vector_io/
qdrant.yaml
safety/
llama-guard.yaml
This method is overloaded in that it can be called from a variety of places: during build, during run, during stack construction.
So when building external providers from a module, there are scenarios where the pip package required to import the module might not be available yet.
There is special handling for all of the potential cases this method can be called from.
Args:
config: Optional object containing the external providers directory path
building: Optional bool delineating whether or not this is being called from a build process
Returns:
A dictionary mapping APIs to their available providers
Raises:
FileNotFoundError: If the external providers directory doesn't exist
ValueError: If any provider spec is invalid
"""
registry: dict[Api, dict[str, ProviderSpec]] = {}
for api in providable_apis():
name = api.name.lower()
logger.debug(f"Importing module {name}")
try:
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
registry[api] = {a.provider_type: a for a in module.available_providers()}
except ImportError as e:
logger.warning(f"Failed to import module {name}: {e}")
# Refresh providable APIs with external APIs if any
external_apis = load_external_apis(config)
for api, api_spec in external_apis.items():
name = api_spec.name.lower()
logger.info(f"Importing external API {name} module {api_spec.module}")
try:
module = importlib.import_module(api_spec.module)
registry[api] = {a.provider_type: a for a in module.available_providers()}
except (ImportError, AttributeError) as e:
# Populate the registry with an empty dict to avoid breaking the provider registry
# This assume that the in-tree provider(s) are not available for this API which means
# that users will need to use external providers for this API.
registry[api] = {}
logger.error(
f"Failed to import external API {name}: {e}. Could not populate the in-tree provider(s) registry for {api.name}. \n"
"Install the API package to load any in-tree providers for this API."
)
# Check if config has external providers
if config:
if hasattr(config, "external_providers_dir") and config.external_providers_dir:
registry = get_external_providers_from_dir(registry, config)
# else lets check for modules in each provider
registry = get_external_providers_from_module(
registry=registry,
config=config,
building=(isinstance(config, BuildConfig) or isinstance(config, DistributionSpec)),
)
return registry
def get_external_providers_from_dir(
registry: dict[Api, dict[str, ProviderSpec]], config
) -> dict[Api, dict[str, ProviderSpec]]:
logger.warning(
"Specifying external providers via `external_providers_dir` is being deprecated. Please specify `module:` in the provider instead."
)
external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
if not os.path.exists(external_providers_dir):
raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
logger.info(f"Loading external providers from {external_providers_dir}")
for api in providable_apis():
api_name = api.name.lower()
# Process both remote and inline providers
for provider_type in ["remote", "inline"]:
api_dir = os.path.join(external_providers_dir, provider_type, api_name)
if not os.path.exists(api_dir):
logger.debug(f"No {provider_type} provider directory found for {api_name}")
continue
# Look for provider spec files in the API directory
for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")):
provider_name = os.path.splitext(os.path.basename(spec_path))[0]
logger.info(f"Loading {provider_type} provider spec from {spec_path}")
try:
with open(spec_path) as f:
spec_data = yaml.safe_load(f)
if provider_type == "remote":
spec = _load_remote_provider_spec(spec_data, api)
provider_type_key = f"remote::{provider_name}"
else:
spec = _load_inline_provider_spec(spec_data, api, provider_name)
provider_type_key = f"inline::{provider_name}"
logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
if provider_type_key in registry[api]:
logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
registry[api][provider_type_key] = spec
logger.info(f"Successfully loaded external provider {provider_type_key}")
except yaml.YAMLError as yaml_err:
logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
raise yaml_err
except Exception as e:
logger.error(f"Failed to load provider spec from {spec_path}: {e}")
raise e
return registry
def get_external_providers_from_module(
registry: dict[Api, dict[str, ProviderSpec]], config, building: bool
) -> dict[Api, dict[str, ProviderSpec]]:
provider_list = None
if isinstance(config, BuildConfig):
provider_list = config.distribution_spec.providers.items()
else:
provider_list = config.providers.items()
if provider_list is None:
logger.warning("Could not get list of providers from config")
return registry
for provider_api, providers in provider_list:
for provider in providers:
if not hasattr(provider, "module") or provider.module is None:
continue
# get provider using module
try:
if not building:
package_name = provider.module.split("==")[0]
module = importlib.import_module(f"{package_name}.provider")
# if config class is wrong you will get an error saying module could not be imported
spec = module.get_provider_spec()
else:
# pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run
spec = ProviderSpec(
api=Api(provider_api),
provider_type=provider.provider_type,
is_external=True,
module=provider.module,
config_class="",
)
provider_type = provider.provider_type
# in the case we are building we CANNOT import this module of course because it has not been installed.
# return a partially filled out spec that the build script will populate.
registry[Api(provider_api)][provider_type] = spec
except ModuleNotFoundError as exc:
raise ValueError(
"get_provider_spec not found. If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available"
) from exc
except Exception as e:
logger.error(f"Failed to load provider spec from module {provider.module}: {e}")
raise e
return registry

View file

@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import yaml
from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.core.datatypes import BuildConfig, StackRunConfig
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="core")
def load_external_apis(config: StackRunConfig | BuildConfig | None) -> dict[Api, ExternalApiSpec]:
"""Load external API specifications from the configured directory.
Args:
config: StackRunConfig or BuildConfig containing the external APIs directory path
Returns:
A dictionary mapping API names to their specifications
"""
if not config or not config.external_apis_dir:
return {}
external_apis_dir = config.external_apis_dir.expanduser().resolve()
if not external_apis_dir.is_dir():
logger.error(f"External APIs directory is not a directory: {external_apis_dir}")
return {}
logger.info(f"Loading external APIs from {external_apis_dir}")
external_apis: dict[Api, ExternalApiSpec] = {}
# Look for YAML files in the external APIs directory
for yaml_path in external_apis_dir.glob("*.yaml"):
try:
with open(yaml_path) as f:
spec_data = yaml.safe_load(f)
spec = ExternalApiSpec(**spec_data)
api = Api.add(spec.name)
logger.info(f"Loaded external API spec for {spec.name} from {yaml_path}")
external_apis[api] = spec
except yaml.YAMLError as yaml_err:
logger.error(f"Failed to parse YAML file {yaml_path}: {yaml_err}")
raise
except Exception:
logger.exception(f"Failed to load external API spec from {yaml_path}")
raise
return external_apis

View file

@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from importlib.metadata import version
from pydantic import BaseModel
from llama_stack.apis.inspect import (
HealthInfo,
Inspect,
ListRoutesResponse,
RouteInfo,
VersionInfo,
)
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.external import load_external_apis
from llama_stack.core.server.routes import get_all_api_routes
from llama_stack.providers.datatypes import HealthStatus
class DistributionInspectConfig(BaseModel):
run_config: StackRunConfig
async def get_provider_impl(config, deps):
impl = DistributionInspectImpl(config, deps)
await impl.initialize()
return impl
class DistributionInspectImpl(Inspect):
def __init__(self, config: DistributionInspectConfig, deps):
self.config = config
self.deps = deps
async def initialize(self) -> None:
pass
async def list_routes(self) -> ListRoutesResponse:
run_config: StackRunConfig = self.config.run_config
ret = []
external_apis = load_external_apis(run_config)
all_endpoints = get_all_api_routes(external_apis)
for api, endpoints in all_endpoints.items():
# Always include provider and inspect APIs, filter others based on run config
if api.value in ["providers", "inspect"]:
ret.extend(
[
RouteInfo(
route=e.path,
method=next(iter([m for m in e.methods if m != "HEAD"])),
provider_types=[], # These APIs don't have "real" providers - they're internal to the stack
)
for e, _ in endpoints
if e.methods is not None
]
)
else:
providers = run_config.providers.get(api.value, [])
if providers: # Only process if there are providers for this API
ret.extend(
[
RouteInfo(
route=e.path,
method=next(iter([m for m in e.methods if m != "HEAD"])),
provider_types=[p.provider_type for p in providers],
)
for e, _ in endpoints
if e.methods is not None
]
)
return ListRoutesResponse(data=ret)
async def health(self) -> HealthInfo:
return HealthInfo(status=HealthStatus.OK)
async def version(self) -> VersionInfo:
return VersionInfo(version=version("llama-stack"))
async def shutdown(self) -> None:
pass

View file

@ -0,0 +1,493 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import inspect
import json
import logging
import os
import sys
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from io import BytesIO
from pathlib import Path
from typing import Any, TypeVar, Union, get_args, get_origin
import httpx
import yaml
from fastapi import Response as FastAPIResponse
from llama_stack_client import (
NOT_GIVEN,
APIResponse,
AsyncAPIResponse,
AsyncLlamaStackClient,
AsyncStream,
LlamaStackClient,
)
from pydantic import BaseModel, TypeAdapter
from rich.console import Console
from termcolor import cprint
from llama_stack.core.build import print_pip_install_help
from llama_stack.core.configure import parse_and_maybe_upgrade_config
from llama_stack.core.datatypes import Api, BuildConfig, BuildProvider, DistributionSpec
from llama_stack.core.request_headers import (
PROVIDER_DATA_VAR,
request_provider_data_context,
)
from llama_stack.core.resolver import ProviderRegistry
from llama_stack.core.server.routes import RouteImpls, find_matching_route, initialize_route_impls
from llama_stack.core.stack import (
construct_stack,
get_stack_run_config_from_template,
replace_env_vars,
)
from llama_stack.core.utils.config import redact_sensitive_fields
from llama_stack.core.utils.context import preserve_contexts_async_generator
from llama_stack.core.utils.exec import in_notebook
from llama_stack.providers.utils.telemetry.tracing import (
CURRENT_TRACE_CONTEXT,
end_trace,
setup_logger,
start_trace,
)
logger = logging.getLogger(__name__)
T = TypeVar("T")
def convert_pydantic_to_json_value(value: Any) -> Any:
if isinstance(value, Enum):
return value.value
elif isinstance(value, list):
return [convert_pydantic_to_json_value(item) for item in value]
elif isinstance(value, dict):
return {k: convert_pydantic_to_json_value(v) for k, v in value.items()}
elif isinstance(value, BaseModel):
return json.loads(value.model_dump_json())
else:
return value
def convert_to_pydantic(annotation: Any, value: Any) -> Any:
if isinstance(annotation, type) and annotation in {str, int, float, bool}:
return value
origin = get_origin(annotation)
if origin is list:
item_type = get_args(annotation)[0]
try:
return [convert_to_pydantic(item_type, item) for item in value]
except Exception:
logger.error(f"Error converting list {value} into {item_type}")
return value
elif origin is dict:
key_type, val_type = get_args(annotation)
try:
return {k: convert_to_pydantic(val_type, v) for k, v in value.items()}
except Exception:
logger.error(f"Error converting dict {value} into {val_type}")
return value
try:
# Handle Pydantic models and discriminated unions
return TypeAdapter(annotation).validate_python(value)
except Exception as e:
# TODO: this is workardound for having Union[str, AgentToolGroup] in API schema.
# We should get rid of any non-discriminated unions in the API schema.
if origin is Union:
for union_type in get_args(annotation):
try:
return convert_to_pydantic(union_type, value)
except Exception:
continue
logger.warning(
f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
)
raise ValueError(f"Failed to convert parameter {value} into {annotation}: {e}") from e
class LibraryClientUploadFile:
"""LibraryClient UploadFile object that mimics FastAPI's UploadFile interface."""
def __init__(self, filename: str, content: bytes):
self.filename = filename
self.content = content
self.content_type = "application/octet-stream"
async def read(self) -> bytes:
return self.content
class LibraryClientHttpxResponse:
"""LibraryClient httpx Response object for FastAPI Response conversion."""
def __init__(self, response):
self.content = response.body if isinstance(response.body, bytes) else response.body.encode()
self.status_code = response.status_code
self.headers = response.headers
class LlamaStackAsLibraryClient(LlamaStackClient):
def __init__(
self,
config_path_or_template_name: str,
skip_logger_removal: bool = False,
custom_provider_registry: ProviderRegistry | None = None,
provider_data: dict[str, Any] | None = None,
):
super().__init__()
self.async_client = AsyncLlamaStackAsLibraryClient(
config_path_or_template_name, custom_provider_registry, provider_data
)
self.pool_executor = ThreadPoolExecutor(max_workers=4)
self.skip_logger_removal = skip_logger_removal
self.provider_data = provider_data
self.loop = asyncio.new_event_loop()
def initialize(self):
if in_notebook():
import nest_asyncio
nest_asyncio.apply()
if not self.skip_logger_removal:
self._remove_root_logger_handlers()
# use a new event loop to avoid interfering with the main event loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(self.async_client.initialize())
finally:
asyncio.set_event_loop(None)
def _remove_root_logger_handlers(self):
"""
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
"""
root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
def request(self, *args, **kwargs):
loop = self.loop
asyncio.set_event_loop(loop)
if kwargs.get("stream"):
def sync_generator():
try:
async_stream = loop.run_until_complete(self.async_client.request(*args, **kwargs))
while True:
chunk = loop.run_until_complete(async_stream.__anext__())
yield chunk
except StopAsyncIteration:
pass
finally:
pending = asyncio.all_tasks(loop)
if pending:
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
return sync_generator()
else:
try:
result = loop.run_until_complete(self.async_client.request(*args, **kwargs))
finally:
pending = asyncio.all_tasks(loop)
if pending:
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
return result
class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
def __init__(
self,
config_path_or_template_name: str,
custom_provider_registry: ProviderRegistry | None = None,
provider_data: dict[str, Any] | None = None,
):
super().__init__()
# when using the library client, we should not log to console since many
# of our logs are intended for server-side usage
current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")
if config_path_or_template_name.endswith(".yaml"):
config_path = Path(config_path_or_template_name)
if not config_path.exists():
raise ValueError(f"Config file {config_path} does not exist")
config_dict = replace_env_vars(yaml.safe_load(config_path.read_text()))
config = parse_and_maybe_upgrade_config(config_dict)
else:
# template
config = get_stack_run_config_from_template(config_path_or_template_name)
self.config_path_or_template_name = config_path_or_template_name
self.config = config
self.custom_provider_registry = custom_provider_registry
self.provider_data = provider_data
self.route_impls: RouteImpls | None = None # Initialize to None to prevent AttributeError
async def initialize(self) -> bool:
try:
self.route_impls = None
self.impls = await construct_stack(self.config, self.custom_provider_registry)
except ModuleNotFoundError as _e:
cprint(_e.msg, color="red", file=sys.stderr)
cprint(
"Using llama-stack as a library requires installing dependencies depending on the template (providers) you choose.\n",
color="yellow",
file=sys.stderr,
)
if self.config_path_or_template_name.endswith(".yaml"):
providers: dict[str, list[BuildProvider]] = {}
for api, run_providers in self.config.providers.items():
for provider in run_providers:
providers.setdefault(api, []).append(
BuildProvider(provider_type=provider.provider_type, module=provider.module)
)
providers = dict(providers)
build_config = BuildConfig(
distribution_spec=DistributionSpec(
providers=providers,
),
external_providers_dir=self.config.external_providers_dir,
)
print_pip_install_help(build_config)
else:
prefix = "!" if in_notebook() else ""
cprint(
f"Please run:\n\n{prefix}llama stack build --template {self.config_path_or_template_name} --image-type venv\n\n",
"yellow",
file=sys.stderr,
)
cprint(
"Please check your internet connection and try again.",
"red",
file=sys.stderr,
)
raise _e
if Api.telemetry in self.impls:
setup_logger(self.impls[Api.telemetry])
if not os.environ.get("PYTEST_CURRENT_TEST"):
console = Console()
console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
safe_config = redact_sensitive_fields(self.config.model_dump())
console.print(yaml.dump(safe_config, indent=2))
self.route_impls = initialize_route_impls(self.impls)
return True
async def request(
self,
cast_to: Any,
options: Any,
*,
stream=False,
stream_cls=None,
):
if self.route_impls is None:
raise ValueError("Client not initialized. Please call initialize() first.")
# Create headers with provider data if available
headers = options.headers or {}
if self.provider_data:
keys = ["X-LlamaStack-Provider-Data", "x-llamastack-provider-data"]
if all(key not in headers for key in keys):
headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data)
# Use context manager for provider data
with request_provider_data_context(headers):
if stream:
response = await self._call_streaming(
cast_to=cast_to,
options=options,
stream_cls=stream_cls,
)
else:
response = await self._call_non_streaming(
cast_to=cast_to,
options=options,
)
return response
def _handle_file_uploads(self, options: Any, body: dict) -> tuple[dict, list[str]]:
"""Handle file uploads from OpenAI client and add them to the request body."""
if not (hasattr(options, "files") and options.files):
return body, []
if not isinstance(options.files, list):
return body, []
field_names = []
for file_tuple in options.files:
if not (isinstance(file_tuple, tuple) and len(file_tuple) >= 2):
continue
field_name = file_tuple[0]
file_object = file_tuple[1]
if isinstance(file_object, BytesIO):
file_object.seek(0)
file_content = file_object.read()
filename = getattr(file_object, "name", "uploaded_file")
field_names.append(field_name)
body[field_name] = LibraryClientUploadFile(filename, file_content)
return body, field_names
async def _call_non_streaming(
self,
*,
cast_to: Any,
options: Any,
):
assert self.route_impls is not None # Should be guaranteed by request() method, assertion for mypy
path = options.url
body = options.params or {}
body |= options.json_data or {}
matched_func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
body |= path_params
body, field_names = self._handle_file_uploads(options, body)
body = self._convert_body(path, options.method, body, exclude_params=set(field_names))
trace_path = webmethod.descriptive_name or route_path
await start_trace(trace_path, {"__location__": "library_client"})
try:
result = await matched_func(**body)
finally:
await end_trace()
# Handle FastAPI Response objects (e.g., from file content retrieval)
if isinstance(result, FastAPIResponse):
return LibraryClientHttpxResponse(result)
json_content = json.dumps(convert_pydantic_to_json_value(result))
filtered_body = {k: v for k, v in body.items() if not isinstance(v, LibraryClientUploadFile)}
mock_response = httpx.Response(
status_code=httpx.codes.OK,
content=json_content.encode("utf-8"),
headers={
"Content-Type": "application/json",
},
request=httpx.Request(
method=options.method,
url=options.url,
params=options.params,
headers=options.headers or {},
json=convert_pydantic_to_json_value(filtered_body),
),
)
response = APIResponse(
raw=mock_response,
client=self,
cast_to=cast_to,
options=options,
stream=False,
stream_cls=None,
)
return response.parse()
async def _call_streaming(
self,
*,
cast_to: Any,
options: Any,
stream_cls: Any,
):
assert self.route_impls is not None # Should be guaranteed by request() method, assertion for mypy
path = options.url
body = options.params or {}
body |= options.json_data or {}
func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
body |= path_params
body = self._convert_body(path, options.method, body)
trace_path = webmethod.descriptive_name or route_path
await start_trace(trace_path, {"__location__": "library_client"})
async def gen():
try:
async for chunk in await func(**body):
data = json.dumps(convert_pydantic_to_json_value(chunk))
sse_event = f"data: {data}\n\n"
yield sse_event.encode("utf-8")
finally:
await end_trace()
wrapped_gen = preserve_contexts_async_generator(gen(), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR])
mock_response = httpx.Response(
status_code=httpx.codes.OK,
content=wrapped_gen,
headers={
"Content-Type": "application/json",
},
request=httpx.Request(
method=options.method,
url=options.url,
params=options.params,
headers=options.headers or {},
json=convert_pydantic_to_json_value(body),
),
)
# we use asynchronous impl always internally and channel all requests to AsyncLlamaStackClient
# however, the top-level caller may be a SyncAPIClient -- so its stream_cls might be a Stream (SyncStream)
# so we need to convert it to AsyncStream
# mypy can't track runtime variables inside the [...] of a generic, so ignore that check
args = get_args(stream_cls)
stream_cls = AsyncStream[args[0]] # type: ignore[valid-type]
response = AsyncAPIResponse(
raw=mock_response,
client=self,
cast_to=cast_to,
options=options,
stream=True,
stream_cls=stream_cls,
)
return await response.parse()
def _convert_body(
self, path: str, method: str, body: dict | None = None, exclude_params: set[str] | None = None
) -> dict:
if not body:
return {}
assert self.route_impls is not None # Should be guaranteed by request() method, assertion for mypy
exclude_params = exclude_params or set()
func, _, _, _ = find_matching_route(method, path, self.route_impls)
sig = inspect.signature(func)
# Strip NOT_GIVENs to use the defaults in signature
body = {k: v for k, v in body.items() if v is not NOT_GIVEN}
# Convert parameters to Pydantic models where needed
converted_body = {}
for param_name, param in sig.parameters.items():
if param_name in body:
value = body.get(param_name)
if param_name in exclude_params:
converted_body[param_name] = value
else:
converted_body[param_name] = convert_to_pydantic(param.annotation, value)
return converted_body

View file

@ -0,0 +1,137 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import Any
from pydantic import BaseModel
from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus
from .datatypes import StackRunConfig
from .utils.config import redact_sensitive_fields
logger = get_logger(name=__name__, category="core")
class ProviderImplConfig(BaseModel):
run_config: StackRunConfig
async def get_provider_impl(config, deps):
impl = ProviderImpl(config, deps)
await impl.initialize()
return impl
class ProviderImpl(Providers):
def __init__(self, config, deps):
self.config = config
self.deps = deps
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
logger.debug("ProviderImpl.shutdown")
pass
async def list_providers(self) -> ListProvidersResponse:
run_config = self.config.run_config
safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))
providers_health = await self.get_providers_health()
ret = []
for api, providers in safe_config.providers.items():
for p in providers:
# Skip providers that are not enabled
if p.provider_id is None:
continue
ret.append(
ProviderInfo(
api=api,
provider_id=p.provider_id,
provider_type=p.provider_type,
config=p.config,
health=providers_health.get(api, {}).get(
p.provider_id,
HealthResponse(
status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"
),
),
)
)
return ListProvidersResponse(data=ret)
async def inspect_provider(self, provider_id: str) -> ProviderInfo:
all_providers = await self.list_providers()
for p in all_providers.data:
if p.provider_id == provider_id:
return p
raise ValueError(f"Provider {provider_id} not found")
async def get_providers_health(self) -> dict[str, dict[str, HealthResponse]]:
"""Get health status for all providers.
Returns:
Dict[str, Dict[str, HealthResponse]]: A dictionary mapping API names to provider health statuses.
Each API maps to a dictionary of provider IDs to their health responses.
"""
providers_health: dict[str, dict[str, HealthResponse]] = {}
# The timeout has to be long enough to allow all the providers to be checked, especially in
# the case of the inference router health check since it checks all registered inference
# providers.
# The timeout must not be equal to the one set by health method for a given implementation,
# otherwise we will miss some providers.
timeout = 3.0
async def check_provider_health(impl: Any) -> tuple[str, HealthResponse] | None:
# Skip special implementations (inspect/providers) that don't have provider specs
if not hasattr(impl, "__provider_spec__"):
return None
api_name = impl.__provider_spec__.api.name
if not hasattr(impl, "health"):
return (
api_name,
HealthResponse(
status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"
),
)
try:
health = await asyncio.wait_for(impl.health(), timeout=timeout)
return api_name, health
except TimeoutError:
return (
api_name,
HealthResponse(
status=HealthStatus.ERROR, message=f"Health check timed out after {timeout} seconds"
),
)
except Exception as e:
return (
api_name,
HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"),
)
# Create tasks for all providers
tasks = [check_provider_health(impl) for impl in self.deps.values()]
# Wait for all health checks to complete
results = await asyncio.gather(*tasks)
# Organize results by API and provider ID
for result in results:
if result is None: # Skip special implementations
continue
api_name, health_response = result
providers_health[api_name] = health_response
return providers_health

View file

@ -0,0 +1,115 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import contextvars
import json
import logging
from contextlib import AbstractContextManager
from typing import Any
from llama_stack.core.datatypes import User
from .utils.dynamic import instantiate_class_type
log = logging.getLogger(__name__)
# Context variable for request provider data and auth attributes
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
class RequestProviderDataContext(AbstractContextManager):
"""Context manager for request provider data"""
def __init__(self, provider_data: dict[str, Any] | None = None, user: User | None = None):
self.provider_data = provider_data or {}
if user:
self.provider_data["__authenticated_user"] = user
self.token = None
def __enter__(self):
# Save the current value and set the new one
self.token = PROVIDER_DATA_VAR.set(self.provider_data)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Restore the previous value
if self.token is not None:
PROVIDER_DATA_VAR.reset(self.token)
class NeedsRequestProviderData:
def get_request_provider_data(self) -> Any:
spec = self.__provider_spec__
if not spec:
raise ValueError(f"Provider spec not set on {self.__class__}")
provider_type = spec.provider_type
validator_class = spec.provider_data_validator
if not validator_class:
raise ValueError(f"Provider {provider_type} does not have a validator")
val = PROVIDER_DATA_VAR.get()
if not val:
return None
validator = instantiate_class_type(validator_class)
try:
provider_data = validator(**val)
return provider_data
except Exception as e:
log.error(f"Error parsing provider data: {e}")
return None
def parse_request_provider_data(headers: dict[str, str]) -> dict[str, Any] | None:
"""Parse provider data from request headers"""
keys = [
"X-LlamaStack-Provider-Data",
"x-llamastack-provider-data",
]
val = None
for key in keys:
val = headers.get(key, None)
if val:
break
if not val:
return None
try:
return json.loads(val)
except json.JSONDecodeError:
log.error("Provider data not encoded as a JSON object!")
return None
def request_provider_data_context(
headers: dict[str, str], auth_attributes: dict[str, list[str]] | None = None
) -> AbstractContextManager:
"""Context manager that sets request provider data from headers and auth attributes for the duration of the context"""
provider_data = parse_request_provider_data(headers)
return RequestProviderDataContext(provider_data, auth_attributes)
def get_authenticated_user() -> User | None:
"""Helper to retrieve auth attributes from the provider data context"""
provider_data = PROVIDER_DATA_VAR.get()
if not provider_data:
return None
return provider_data.get("__authenticated_user")
def user_from_scope(scope: dict) -> User | None:
"""Create a User object from ASGI scope data (set by authentication middleware)"""
user_attributes = scope.get("user_attributes", {})
principal = scope.get("principal", "")
# auth not enabled
if not principal and not user_attributes:
return None
return User(principal=principal, attributes=user_attributes)

View file

@ -0,0 +1,462 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import importlib
import inspect
from typing import Any
from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.datatypes import ExternalApiSpec
from llama_stack.apis.eval import Eval
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InferenceProvider
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.providers import Providers as ProvidersAPI
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.client import get_client_impl
from llama_stack.core.datatypes import (
AccessRule,
AutoRoutedProviderSpec,
Provider,
RoutingTableProviderSpec,
StackRunConfig,
)
from llama_stack.core.distribution import builtin_automatically_routed_apis
from llama_stack.core.external import load_external_apis
from llama_stack.core.store import DistributionRegistry
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
Api,
BenchmarksProtocolPrivate,
DatasetsProtocolPrivate,
ModelsProtocolPrivate,
ProviderSpec,
RemoteProviderConfig,
RemoteProviderSpec,
ScoringFunctionsProtocolPrivate,
ShieldsProtocolPrivate,
ToolGroupsProtocolPrivate,
VectorDBsProtocolPrivate,
)
logger = get_logger(name=__name__, category="core")
class InvalidProviderError(Exception):
pass
def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) -> dict[Api, Any]:
"""Get a mapping of API types to their protocol classes.
Args:
external_apis: Optional dictionary of external API specifications
Returns:
Dictionary mapping API types to their protocol classes
"""
protocols = {
Api.providers: ProvidersAPI,
Api.agents: Agents,
Api.inference: Inference,
Api.inspect: Inspect,
Api.vector_io: VectorIO,
Api.vector_dbs: VectorDBs,
Api.models: Models,
Api.safety: Safety,
Api.shields: Shields,
Api.telemetry: Telemetry,
Api.datasetio: DatasetIO,
Api.datasets: Datasets,
Api.scoring: Scoring,
Api.scoring_functions: ScoringFunctions,
Api.eval: Eval,
Api.benchmarks: Benchmarks,
Api.post_training: PostTraining,
Api.tool_groups: ToolGroups,
Api.tool_runtime: ToolRuntime,
Api.files: Files,
}
if external_apis:
for api, api_spec in external_apis.items():
try:
module = importlib.import_module(api_spec.module)
api_class = getattr(module, api_spec.protocol)
protocols[api] = api_class
except (ImportError, AttributeError):
logger.exception(f"Failed to load external API {api_spec.name}")
return protocols
def api_protocol_map_for_compliance_check(config: Any) -> dict[Api, Any]:
external_apis = load_external_apis(config)
return {
**api_protocol_map(external_apis),
Api.inference: InferenceProvider,
}
def additional_protocols_map() -> dict[Api, Any]:
return {
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups),
Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
Api.scoring: (
ScoringFunctionsProtocolPrivate,
ScoringFunctions,
Api.scoring_functions,
),
Api.eval: (BenchmarksProtocolPrivate, Benchmarks, Api.benchmarks),
}
# TODO: make all this naming far less atrocious. Provider. ProviderSpec. ProviderWithSpec. WTF!
class ProviderWithSpec(Provider):
spec: ProviderSpec
ProviderRegistry = dict[Api, dict[str, ProviderSpec]]
async def resolve_impls(
run_config: StackRunConfig,
provider_registry: ProviderRegistry,
dist_registry: DistributionRegistry,
policy: list[AccessRule],
) -> dict[Api, Any]:
"""
Resolves provider implementations by:
1. Validating and organizing providers.
2. Sorting them in dependency order.
3. Instantiating them with required dependencies.
"""
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
router_apis = {x.router_api for x in builtin_automatically_routed_apis()}
providers_with_specs = validate_and_prepare_providers(
run_config, provider_registry, routing_table_apis, router_apis
)
apis_to_serve = run_config.apis or set(
list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
)
providers_with_specs.update(specs_for_autorouted_apis(apis_to_serve))
sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config, policy)
def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
"""Generates specifications for automatically routed APIs."""
specs = {}
for info in builtin_automatically_routed_apis():
if info.router_api.value not in apis_to_serve:
continue
specs[info.routing_table_api.value] = {
"__builtin__": ProviderWithSpec(
provider_id="__routing_table__",
provider_type="__routing_table__",
config={},
spec=RoutingTableProviderSpec(
api=info.routing_table_api,
router_api=info.router_api,
module="llama_stack.core.routers",
api_dependencies=[],
deps__=[f"inner-{info.router_api.value}"],
),
)
}
specs[info.router_api.value] = {
"__builtin__": ProviderWithSpec(
provider_id="__autorouted__",
provider_type="__autorouted__",
config={},
spec=AutoRoutedProviderSpec(
api=info.router_api,
module="llama_stack.core.routers",
routing_table_api=info.routing_table_api,
api_dependencies=[info.routing_table_api],
# Add telemetry as an optional dependency to all auto-routed providers
optional_api_dependencies=[Api.telemetry],
deps__=([info.routing_table_api.value, Api.telemetry.value]),
),
)
}
return specs
def validate_and_prepare_providers(
run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: set[Api], router_apis: set[Api]
) -> dict[str, dict[str, ProviderWithSpec]]:
"""Validates providers, handles deprecations, and organizes them into a spec dictionary."""
providers_with_specs: dict[str, dict[str, ProviderWithSpec]] = {}
for api_str, providers in run_config.providers.items():
api = Api(api_str)
if api in routing_table_apis:
raise ValueError(f"Provider for `{api_str}` is automatically provided and cannot be overridden")
specs = {}
for provider in providers:
if not provider.provider_id or provider.provider_id == "__disabled__":
logger.debug(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
continue
validate_provider(provider, api, provider_registry)
p = provider_registry[api][provider.provider_type]
p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
spec = ProviderWithSpec(spec=p, **provider.model_dump())
specs[provider.provider_id] = spec
key = api_str if api not in router_apis else f"inner-{api_str}"
providers_with_specs[key] = specs
return providers_with_specs
def validate_provider(provider: Provider, api: Api, provider_registry: ProviderRegistry):
"""Validates if the provider is allowed and handles deprecations."""
if provider.provider_type not in provider_registry[api]:
raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")
p = provider_registry[api][provider.provider_type]
if p.deprecation_error:
logger.error(p.deprecation_error)
raise InvalidProviderError(p.deprecation_error)
elif p.deprecation_warning:
logger.warning(
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
)
def sort_providers_by_deps(
providers_with_specs: dict[str, dict[str, ProviderWithSpec]], run_config: StackRunConfig
) -> list[tuple[str, ProviderWithSpec]]:
"""Sorts providers based on their dependencies."""
sorted_providers: list[tuple[str, ProviderWithSpec]] = topological_sort(
{k: list(v.values()) for k, v in providers_with_specs.items()}
)
logger.debug(f"Resolved {len(sorted_providers)} providers")
for api_str, provider in sorted_providers:
logger.debug(f" {api_str} => {provider.provider_id}")
return sorted_providers
async def instantiate_providers(
sorted_providers: list[tuple[str, ProviderWithSpec]],
router_apis: set[Api],
dist_registry: DistributionRegistry,
run_config: StackRunConfig,
policy: list[AccessRule],
) -> dict[Api, Any]:
"""Instantiates providers asynchronously while managing dependencies."""
impls: dict[Api, Any] = {}
inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
for api_str, provider in sorted_providers:
# Skip providers that are not enabled
if provider.provider_id is None:
continue
deps = {a: impls[a] for a in provider.spec.api_dependencies}
for a in provider.spec.optional_api_dependencies:
if a in impls:
deps[a] = impls[a]
inner_impls = {}
if isinstance(provider.spec, RoutingTableProviderSpec):
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config, policy)
if api_str.startswith("inner-"):
inner_impls_by_provider_id[api_str][provider.provider_id] = impl
else:
api = Api(api_str)
impls[api] = impl
return impls
def topological_sort(
providers_with_specs: dict[str, list[ProviderWithSpec]],
) -> list[tuple[str, ProviderWithSpec]]:
def dfs(kv, visited: set[str], stack: list[str]):
api_str, providers = kv
visited.add(api_str)
deps = []
for provider in providers:
for dep in provider.spec.deps__:
deps.append(dep)
for dep in deps:
if dep not in visited and dep in providers_with_specs:
dfs((dep, providers_with_specs[dep]), visited, stack)
stack.append(api_str)
visited: set[str] = set()
stack: list[str] = []
for api_str, providers in providers_with_specs.items():
if api_str not in visited:
dfs((api_str, providers), visited, stack)
flattened = []
for api_str in stack:
for provider in providers_with_specs[api_str]:
flattened.append((api_str, provider))
return flattened
# returns a class implementing the protocol corresponding to the Api
async def instantiate_provider(
provider: ProviderWithSpec,
deps: dict[Api, Any],
inner_impls: dict[str, Any],
dist_registry: DistributionRegistry,
run_config: StackRunConfig,
policy: list[AccessRule],
):
provider_spec = provider.spec
if not hasattr(provider_spec, "module") or provider_spec.module is None:
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
logger.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}")
module = importlib.import_module(provider_spec.module)
args = []
if isinstance(provider_spec, RemoteProviderSpec):
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider.config)
method = "get_adapter_impl"
args = [config, deps]
elif isinstance(provider_spec, AutoRoutedProviderSpec):
method = "get_auto_router_impl"
config = None
args = [provider_spec.api, deps[provider_spec.routing_table_api], deps, run_config, policy]
elif isinstance(provider_spec, RoutingTableProviderSpec):
method = "get_routing_table_impl"
config = None
args = [provider_spec.api, inner_impls, deps, dist_registry, policy]
else:
method = "get_provider_impl"
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider.config)
args = [config, deps]
if "policy" in inspect.signature(getattr(module, method)).parameters:
args.append(policy)
fn = getattr(module, method)
impl = await fn(*args)
impl.__provider_id__ = provider.provider_id
impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config
protocols = api_protocol_map_for_compliance_check(run_config)
additional_protocols = additional_protocols_map()
# TODO: check compliance for special tool groups
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
check_protocol_compliance(impl, protocols[provider_spec.api])
if not isinstance(provider_spec, AutoRoutedProviderSpec) and provider_spec.api in additional_protocols:
additional_api, _, _ = additional_protocols[provider_spec.api]
check_protocol_compliance(impl, additional_api)
return impl
def check_protocol_compliance(obj: Any, protocol: Any) -> None:
missing_methods = []
mro = type(obj).__mro__
for name, value in inspect.getmembers(protocol):
if inspect.isfunction(value) and hasattr(value, "__webmethod__"):
if value.__webmethod__.experimental:
continue
if not hasattr(obj, name):
missing_methods.append((name, "missing"))
elif not callable(getattr(obj, name)):
missing_methods.append((name, "not_callable"))
else:
# Check if the method signatures are compatible
obj_method = getattr(obj, name)
proto_sig = inspect.signature(value)
obj_sig = inspect.signature(obj_method)
proto_params = set(proto_sig.parameters)
proto_params.discard("self")
obj_params = set(obj_sig.parameters)
obj_params.discard("self")
if not (proto_params <= obj_params):
logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
missing_methods.append((name, "signature_mismatch"))
else:
# Check if the method has a concrete implementation (not just a protocol stub)
# Find all classes in MRO that define this method
method_owners = [cls for cls in mro if name in cls.__dict__]
# Allow methods from mixins/parents, only reject if ONLY the protocol defines it
if len(method_owners) == 1 and method_owners[0].__name__ == protocol.__name__:
# Only reject if the method is ONLY defined in the protocol itself (abstract stub)
missing_methods.append((name, "not_actually_implemented"))
if missing_methods:
raise ValueError(
f"Provider `{obj.__provider_id__} ({obj.__provider_spec__.api})` does not implement the following methods:\n{missing_methods}"
)
async def resolve_remote_stack_impls(
config: RemoteProviderConfig,
apis: list[str],
) -> dict[Api, Any]:
protocols = api_protocol_map()
additional_protocols = additional_protocols_map()
impls = {}
for api_str in apis:
api = Api(api_str)
impls[api] = await get_client_impl(
protocols[api],
config,
{},
)
if api in additional_protocols:
_, additional_protocol, additional_api = additional_protocols[api]
impls[additional_api] = await get_client_impl(
additional_protocol,
config,
{},
)
return impls

View file

@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.core.datatypes import AccessRule, RoutedProtocol
from llama_stack.core.stack import StackRunConfig
from llama_stack.core.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable
from llama_stack.providers.utils.inference.inference_store import InferenceStore
async def get_routing_table_impl(
api: Api,
impls_by_provider_id: dict[str, RoutedProtocol],
_deps,
dist_registry: DistributionRegistry,
policy: list[AccessRule],
) -> Any:
from ..routing_tables.benchmarks import BenchmarksRoutingTable
from ..routing_tables.datasets import DatasetsRoutingTable
from ..routing_tables.models import ModelsRoutingTable
from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable
from ..routing_tables.shields import ShieldsRoutingTable
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
from ..routing_tables.vector_dbs import VectorDBsRoutingTable
api_to_tables = {
"vector_dbs": VectorDBsRoutingTable,
"models": ModelsRoutingTable,
"shields": ShieldsRoutingTable,
"datasets": DatasetsRoutingTable,
"scoring_functions": ScoringFunctionsRoutingTable,
"benchmarks": BenchmarksRoutingTable,
"tool_groups": ToolGroupsRoutingTable,
}
if api.value not in api_to_tables:
raise ValueError(f"API {api.value} not found in router map")
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry, policy)
await impl.initialize()
return impl
async def get_auto_router_impl(
api: Api, routing_table: RoutingTable, deps: dict[str, Any], run_config: StackRunConfig, policy: list[AccessRule]
) -> Any:
from .datasets import DatasetIORouter
from .eval_scoring import EvalRouter, ScoringRouter
from .inference import InferenceRouter
from .safety import SafetyRouter
from .tool_runtime import ToolRuntimeRouter
from .vector_io import VectorIORouter
api_to_routers = {
"vector_io": VectorIORouter,
"inference": InferenceRouter,
"safety": SafetyRouter,
"datasetio": DatasetIORouter,
"scoring": ScoringRouter,
"eval": EvalRouter,
"tool_runtime": ToolRuntimeRouter,
}
api_to_deps = {
"inference": {"telemetry": Api.telemetry},
}
if api.value not in api_to_routers:
raise ValueError(f"API {api.value} not found in router map")
api_to_dep_impl = {}
for dep_name, dep_api in api_to_deps.get(api.value, {}).items():
if dep_api in deps:
api_to_dep_impl[dep_name] = deps[dep_api]
# TODO: move pass configs to routers instead
if api == Api.inference and run_config.inference_store:
inference_store = InferenceStore(run_config.inference_store, policy)
await inference_store.initialize()
api_to_dep_impl["store"] = inference_store
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
await impl.initialize()
return impl

View file

@ -0,0 +1,73 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
class DatasetIORouter(DatasetIO):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing DatasetIORouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("DatasetIORouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("DatasetIORouter.shutdown")
pass
async def register_dataset(
self,
purpose: DatasetPurpose,
source: DataSource,
metadata: dict[str, Any] | None = None,
dataset_id: str | None = None,
) -> None:
logger.debug(
f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
)
await self.routing_table.register_dataset(
purpose=purpose,
source=source,
metadata=metadata,
dataset_id=dataset_id,
)
async def iterrows(
self,
dataset_id: str,
start_index: int | None = None,
limit: int | None = None,
) -> PaginatedResponse:
logger.debug(
f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
)
provider = await self.routing_table.get_provider_impl(dataset_id)
return await provider.iterrows(
dataset_id=dataset_id,
start_index=start_index,
limit=limit,
)
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
provider = await self.routing_table.get_provider_impl(dataset_id)
return await provider.append_rows(
dataset_id=dataset_id,
rows=rows,
)

View file

@ -0,0 +1,155 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
from llama_stack.apis.scoring import (
ScoreBatchResponse,
ScoreResponse,
Scoring,
ScoringFnParams,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
class ScoringRouter(Scoring):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing ScoringRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("ScoringRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("ScoringRouter.shutdown")
pass
async def score_batch(
self,
dataset_id: str,
scoring_functions: dict[str, ScoringFnParams | None] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
res = {}
for fn_identifier in scoring_functions.keys():
provider = await self.routing_table.get_provider_impl(fn_identifier)
score_response = await provider.score_batch(
dataset_id=dataset_id,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
res.update(score_response.results)
if save_results_dataset:
raise NotImplementedError("Save results dataset not implemented yet")
return ScoreBatchResponse(
results=res,
)
async def score(
self,
input_rows: list[dict[str, Any]],
scoring_functions: dict[str, ScoringFnParams | None] = None,
) -> ScoreResponse:
logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
res = {}
# look up and map each scoring function to its provider impl
for fn_identifier in scoring_functions.keys():
provider = await self.routing_table.get_provider_impl(fn_identifier)
score_response = await provider.score(
input_rows=input_rows,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
res.update(score_response.results)
return ScoreResponse(results=res)
class EvalRouter(Eval):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing EvalRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("EvalRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("EvalRouter.shutdown")
pass
async def run_eval(
self,
benchmark_id: str,
benchmark_config: BenchmarkConfig,
) -> Job:
logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
provider = await self.routing_table.get_provider_impl(benchmark_id)
return await provider.run_eval(
benchmark_id=benchmark_id,
benchmark_config=benchmark_config,
)
async def evaluate_rows(
self,
benchmark_id: str,
input_rows: list[dict[str, Any]],
scoring_functions: list[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
provider = await self.routing_table.get_provider_impl(benchmark_id)
return await provider.evaluate_rows(
benchmark_id=benchmark_id,
input_rows=input_rows,
scoring_functions=scoring_functions,
benchmark_config=benchmark_config,
)
async def job_status(
self,
benchmark_id: str,
job_id: str,
) -> Job:
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
provider = await self.routing_table.get_provider_impl(benchmark_id)
return await provider.job_status(benchmark_id, job_id)
async def job_cancel(
self,
benchmark_id: str,
job_id: str,
) -> None:
logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
provider = await self.routing_table.get_provider_impl(benchmark_id)
await provider.job_cancel(
benchmark_id,
job_id,
)
async def job_result(
self,
benchmark_id: str,
job_id: str,
) -> EvaluateResponse:
logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
provider = await self.routing_table.get_provider_impl(benchmark_id)
return await provider.job_result(
benchmark_id,
job_id,
)

View file

@ -0,0 +1,627 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import time
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Annotated, Any
from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
from pydantic import Field, TypeAdapter
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.common.errors import ModelNotFoundError
from llama_stack.apis.inference import (
BatchChatCompletionResponse,
BatchCompletionResponse,
ChatCompletionResponse,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
ListOpenAIChatCompletionResponse,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAICompletionWithInputMessages,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
Order,
ResponseFormat,
SamplingParams,
StopReason,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.inference.stream_utils import stream_and_store_openai_completion
from llama_stack.providers.utils.telemetry.tracing import get_current_span
logger = get_logger(name=__name__, category="core")
class InferenceRouter(Inference):
"""Routes to an provider based on the model"""
def __init__(
self,
routing_table: RoutingTable,
telemetry: Telemetry | None = None,
store: InferenceStore | None = None,
) -> None:
logger.debug("Initializing InferenceRouter")
self.routing_table = routing_table
self.telemetry = telemetry
self.store = store
if self.telemetry:
self.tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(self.tokenizer)
async def initialize(self) -> None:
logger.debug("InferenceRouter.initialize")
async def shutdown(self) -> None:
logger.debug("InferenceRouter.shutdown")
async def register_model(
self,
model_id: str,
provider_model_id: str | None = None,
provider_id: str | None = None,
metadata: dict[str, Any] | None = None,
model_type: ModelType | None = None,
) -> None:
logger.debug(
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
)
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
def _construct_metrics(
self,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
model: Model,
) -> list[MetricEvent]:
"""Constructs a list of MetricEvent objects containing token usage metrics.
Args:
prompt_tokens: Number of tokens in the prompt
completion_tokens: Number of tokens in the completion
total_tokens: Total number of tokens used
model: Model object containing model_id and provider_id
Returns:
List of MetricEvent objects with token usage metrics
"""
span = get_current_span()
if span is None:
logger.warning("No span found for token usage metrics")
return []
metrics = [
("prompt_tokens", prompt_tokens),
("completion_tokens", completion_tokens),
("total_tokens", total_tokens),
]
metric_events = []
for metric_name, value in metrics:
metric_events.append(
MetricEvent(
trace_id=span.trace_id,
span_id=span.span_id,
metric=metric_name,
value=value,
timestamp=time.time(),
unit="tokens",
attributes={
"model_id": model.model_id,
"provider_id": model.provider_id,
},
)
)
return metric_events
async def _compute_and_log_token_usage(
self,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
model: Model,
) -> list[MetricInResponse]:
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
if self.telemetry:
for metric in metrics:
await self.telemetry.log_event(metric)
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
async def _count_tokens(
self,
messages: list[Message] | InterleavedContent,
tool_prompt_format: ToolPromptFormat | None = None,
) -> int | None:
if not hasattr(self, "formatter") or self.formatter is None:
return None
if isinstance(messages, list):
encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
else:
encoded = self.formatter.encode_content(messages)
return len(encoded.tokens) if encoded and encoded.tokens else 0
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = None,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
logger.debug(
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
)
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.routing_table.get_model(model_id)
if model is None:
raise ModelNotFoundError(model_id)
if model.model_type == ModelType.embedding:
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
if tool_config:
if tool_choice and tool_choice != tool_config.tool_choice:
raise ValueError("tool_choice and tool_config.tool_choice must match")
if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format:
raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
else:
params = {}
if tool_choice:
params["tool_choice"] = tool_choice
if tool_prompt_format:
params["tool_prompt_format"] = tool_prompt_format
tool_config = ToolConfig(**params)
tools = tools or []
if tool_config.tool_choice == ToolChoice.none:
tools = []
elif tool_config.tool_choice == ToolChoice.auto:
pass
elif tool_config.tool_choice == ToolChoice.required:
pass
else:
# verify tool_choice is one of the tools
tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
if tool_config.tool_choice not in tool_names:
raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")
params = dict(
model_id=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
provider = await self.routing_table.get_provider_impl(model_id)
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
if stream:
async def stream_generator():
completion_text = ""
async for chunk in await provider.chat_completion(**params):
if chunk.event.event_type == ChatCompletionResponseEventType.progress:
if chunk.event.delta.type == "text":
completion_text += chunk.event.delta.text
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
completion_tokens = await self._count_tokens(
[
CompletionMessage(
content=completion_text,
stop_reason=StopReason.end_of_turn,
)
],
tool_config.tool_prompt_format,
)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
yield chunk
return stream_generator()
else:
response = await provider.chat_completion(**params)
completion_tokens = await self._count_tokens(
[response.completion_message],
tool_config.tool_prompt_format,
)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response
async def batch_chat_completion(
self,
model_id: str,
messages_batch: list[list[Message]],
tools: list[ToolDefinition] | None = None,
tool_config: ToolConfig | None = None,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> BatchChatCompletionResponse:
logger.debug(
f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
)
provider = await self.routing_table.get_provider_impl(model_id)
return await provider.batch_chat_completion(
model_id=model_id,
messages_batch=messages_batch,
tools=tools,
tool_config=tool_config,
sampling_params=sampling_params,
response_format=response_format,
logprobs=logprobs,
)
async def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
logger.debug(
f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
)
model = await self.routing_table.get_model(model_id)
if model is None:
raise ModelNotFoundError(model_id)
if model.model_type == ModelType.embedding:
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
provider = await self.routing_table.get_provider_impl(model_id)
params = dict(
model_id=model_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
prompt_tokens = await self._count_tokens(content)
if stream:
async def stream_generator():
completion_text = ""
async for chunk in await provider.completion(**params):
if hasattr(chunk, "delta"):
completion_text += chunk.delta
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
completion_tokens = await self._count_tokens(completion_text)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
yield chunk
return stream_generator()
else:
response = await provider.completion(**params)
completion_tokens = await self._count_tokens(response.content)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response
async def batch_completion(
self,
model_id: str,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> BatchCompletionResponse:
logger.debug(
f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
)
provider = await self.routing_table.get_provider_impl(model_id)
return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs)
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
logger.debug(f"InferenceRouter.embeddings: {model_id}")
model = await self.routing_table.get_model(model_id)
if model is None:
raise ModelNotFoundError(model_id)
if model.model_type == ModelType.llm:
raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
provider = await self.routing_table.get_provider_impl(model_id)
return await provider.embeddings(
model_id=model_id,
contents=contents,
text_truncation=text_truncation,
output_dimension=output_dimension,
task_type=task_type,
)
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
logger.debug(
f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
)
model_obj = await self.routing_table.get_model(model)
if model_obj is None:
raise ModelNotFoundError(model)
if model_obj.model_type == ModelType.embedding:
raise ValueError(f"Model '{model}' is an embedding model and does not support completions")
params = dict(
model=model_obj.identifier,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
guided_choice=guided_choice,
prompt_logprobs=prompt_logprobs,
suffix=suffix,
)
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
return await provider.openai_completion(**params)
async def openai_chat_completion(
self,
model: str,
messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
logger.debug(
f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
)
model_obj = await self.routing_table.get_model(model)
if model_obj is None:
raise ModelNotFoundError(model)
if model_obj.model_type == ModelType.embedding:
raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions")
# Use the OpenAI client for a bit of extra input validation without
# exposing the OpenAI client itself as part of our API surface
if tool_choice:
TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice)
if tools is None:
raise ValueError("'tool_choice' is only allowed when 'tools' is also provided")
if tools:
for tool in tools:
TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)
# Some providers make tool calls even when tool_choice is "none"
# so just clear them both out to avoid unexpected tool calls
if tool_choice == "none" and tools is not None:
tool_choice = None
tools = None
params = dict(
model=model_obj.identifier,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
if stream:
response_stream = await provider.openai_chat_completion(**params)
if self.store:
return stream_and_store_openai_completion(response_stream, model, self.store, messages)
return response_stream
else:
response = await self._nonstream_openai_chat_completion(provider, params)
if self.store:
await self.store.store_chat_completion(response, messages)
return response
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
logger.debug(
f"InferenceRouter.openai_embeddings: {model=}, input_type={type(input)}, {encoding_format=}, {dimensions=}",
)
model_obj = await self.routing_table.get_model(model)
if model_obj is None:
raise ModelNotFoundError(model)
if model_obj.model_type != ModelType.embedding:
raise ValueError(f"Model '{model}' is not an embedding model")
params = dict(
model=model_obj.identifier,
input=input,
encoding_format=encoding_format,
dimensions=dimensions,
user=user,
)
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
return await provider.openai_embeddings(**params)
async def list_chat_completions(
self,
after: str | None = None,
limit: int | None = 20,
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIChatCompletionResponse:
if self.store:
return await self.store.list_chat_completions(after, limit, model, order)
raise NotImplementedError("List chat completions is not supported: inference store is not configured.")
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
if self.store:
return await self.store.get_chat_completion(completion_id)
raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")
async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
response = await provider.openai_chat_completion(**params)
for choice in response.choices:
# some providers return an empty list for no tool calls in non-streaming responses
# but the OpenAI API returns None. So, set tool_calls to None if it's empty
if choice.message and choice.message.tool_calls is not None and len(choice.message.tool_calls) == 0:
choice.message.tool_calls = None
return response
async def health(self) -> dict[str, HealthResponse]:
health_statuses = {}
timeout = 1 # increasing the timeout to 1 second for health checks
for provider_id, impl in self.routing_table.impls_by_provider_id.items():
try:
# check if the provider has a health method
if not hasattr(impl, "health"):
continue
health = await asyncio.wait_for(impl.health(), timeout=timeout)
health_statuses[provider_id] = health
except TimeoutError:
health_statuses[provider_id] = HealthResponse(
status=HealthStatus.ERROR,
message=f"Health check timed out after {timeout} seconds",
)
except NotImplementedError:
health_statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
except Exception as e:
health_statuses[provider_id] = HealthResponse(
status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
)
return health_statuses

View file

@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.inference import (
Message,
)
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
class SafetyRouter(Safety):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing SafetyRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("SafetyRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("SafetyRouter.shutdown")
pass
async def register_shield(
self,
shield_id: str,
provider_shield_id: str | None = None,
provider_id: str | None = None,
params: dict[str, Any] | None = None,
) -> Shield:
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
async def run_shield(
self,
shield_id: str,
messages: list[Message],
params: dict[str, Any] = None,
) -> RunShieldResponse:
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
provider = await self.routing_table.get_provider_impl(shield_id)
return await provider.run_shield(
shield_id=shield_id,
messages=messages,
params=params,
)

View file

@ -0,0 +1,91 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
)
from llama_stack.apis.tools import (
ListToolsResponse,
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
RAGToolRuntime,
ToolRuntime,
)
from llama_stack.log import get_logger
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
logger = get_logger(name=__name__, category="core")
class ToolRuntimeRouter(ToolRuntime):
class RagToolImpl(RAGToolRuntime):
def __init__(
self,
routing_table: ToolGroupsRoutingTable,
) -> None:
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
self.routing_table = routing_table
async def query(
self,
content: InterleavedContent,
vector_db_ids: list[str],
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
provider = await self.routing_table.get_provider_impl("knowledge_search")
return await provider.query(content, vector_db_ids, query_config)
async def insert(
self,
documents: list[RAGDocument],
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
logger.debug(
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
)
provider = await self.routing_table.get_provider_impl("insert_into_memory")
return await provider.insert(documents, vector_db_id, chunk_size_in_tokens)
def __init__(
self,
routing_table: ToolGroupsRoutingTable,
) -> None:
logger.debug("Initializing ToolRuntimeRouter")
self.routing_table = routing_table
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
self.rag_tool = self.RagToolImpl(routing_table)
for method in ("query", "insert"):
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
async def initialize(self) -> None:
logger.debug("ToolRuntimeRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("ToolRuntimeRouter.shutdown")
pass
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
provider = await self.routing_table.get_provider_impl(tool_name)
return await provider.invoke_tool(
tool_name=tool_name,
kwargs=kwargs,
)
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolsResponse:
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.list_tools(tool_group_id)

View file

@ -0,0 +1,365 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import uuid
from typing import Any
from llama_stack.apis.common.content_types import (
InterleavedContent,
)
from llama_stack.apis.models import ModelType
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
SearchRankingOptions,
VectorIO,
VectorStoreChunkingStrategy,
VectorStoreDeleteResponse,
VectorStoreFileContentsResponse,
VectorStoreFileDeleteResponse,
VectorStoreFileObject,
VectorStoreFileStatus,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
logger = get_logger(name=__name__, category="core")
class VectorIORouter(VectorIO):
"""Routes to an provider based on the vector db identifier"""
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing VectorIORouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("VectorIORouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("VectorIORouter.shutdown")
pass
async def _get_first_embedding_model(self) -> tuple[str, int] | None:
"""Get the first available embedding model identifier."""
try:
# Get all models from the routing table
all_models = await self.routing_table.get_all_with_type("model")
# Filter for embedding models
embedding_models = [
model
for model in all_models
if hasattr(model, "model_type") and model.model_type == ModelType.embedding
]
if embedding_models:
dimension = embedding_models[0].metadata.get("embedding_dimension", None)
if dimension is None:
raise ValueError(f"Embedding model {embedding_models[0].identifier} has no embedding dimension")
return embedding_models[0].identifier, dimension
else:
logger.warning("No embedding models found in the routing table")
return None
except Exception as e:
logger.error(f"Error getting embedding models: {e}")
return None
async def register_vector_db(
self,
vector_db_id: str,
embedding_model: str,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
vector_db_name: str | None = None,
provider_vector_db_id: str | None = None,
) -> None:
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
await self.routing_table.register_vector_db(
vector_db_id,
embedding_model,
embedding_dimension,
provider_id,
vector_db_name,
provider_vector_db_id,
)
async def insert_chunks(
self,
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
logger.debug(
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
)
provider = await self.routing_table.get_provider_impl(vector_db_id)
return await provider.insert_chunks(vector_db_id, chunks, ttl_seconds)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
provider = await self.routing_table.get_provider_impl(vector_db_id)
return await provider.query_chunks(vector_db_id, query, params)
# OpenAI Vector Stores API endpoints
async def openai_create_vector_store(
self,
name: str,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
embedding_model: str | None = None,
embedding_dimension: int | None = None,
provider_id: str | None = None,
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}")
# If no embedding model is provided, use the first available one
if embedding_model is None:
embedding_model_info = await self._get_first_embedding_model()
if embedding_model_info is None:
raise ValueError("No embedding model provided and no embedding models available in the system")
embedding_model, embedding_dimension = embedding_model_info
logger.info(f"No embedding model specified, using first available: {embedding_model}")
vector_db_id = f"vs_{uuid.uuid4()}"
registered_vector_db = await self.routing_table.register_vector_db(
vector_db_id=vector_db_id,
embedding_model=embedding_model,
embedding_dimension=embedding_dimension,
provider_id=provider_id,
provider_vector_db_id=vector_db_id,
vector_db_name=name,
)
provider = await self.routing_table.get_provider_impl(registered_vector_db.identifier)
return await provider.openai_create_vector_store(
name=name,
file_ids=file_ids,
expires_after=expires_after,
chunking_strategy=chunking_strategy,
metadata=metadata,
embedding_model=embedding_model,
embedding_dimension=embedding_dimension,
provider_id=registered_vector_db.provider_id,
provider_vector_db_id=registered_vector_db.provider_resource_id,
)
async def openai_list_vector_stores(
self,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
logger.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}")
# Route to default provider for now - could aggregate from all providers in the future
# call retrieve on each vector dbs to get list of vector stores
vector_dbs = await self.routing_table.get_all_with_type("vector_db")
all_stores = []
for vector_db in vector_dbs:
try:
provider = await self.routing_table.get_provider_impl(vector_db.identifier)
vector_store = await provider.openai_retrieve_vector_store(vector_db.identifier)
all_stores.append(vector_store)
except Exception as e:
logger.error(f"Error retrieving vector store {vector_db.identifier}: {e}")
continue
# Sort by created_at
reverse_order = order == "desc"
all_stores.sort(key=lambda x: x.created_at, reverse=reverse_order)
# Apply cursor-based pagination
if after:
after_index = next((i for i, store in enumerate(all_stores) if store.id == after), -1)
if after_index >= 0:
all_stores = all_stores[after_index + 1 :]
if before:
before_index = next((i for i, store in enumerate(all_stores) if store.id == before), len(all_stores))
all_stores = all_stores[:before_index]
# Apply limit
limited_stores = all_stores[:limit]
# Determine pagination info
has_more = len(all_stores) > limit
first_id = limited_stores[0].id if limited_stores else None
last_id = limited_stores[-1].id if limited_stores else None
return VectorStoreListResponse(
data=limited_stores,
has_more=has_more,
first_id=first_id,
last_id=last_id,
)
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
return await self.routing_table.openai_retrieve_vector_store(vector_store_id)
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")
return await self.routing_table.openai_update_vector_store(
vector_store_id=vector_store_id,
name=name,
expires_after=expires_after,
metadata=metadata,
)
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
return await self.routing_table.openai_delete_vector_store(vector_store_id)
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
search_mode: str | None = "vector",
) -> VectorStoreSearchResponsePage:
logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
return await self.routing_table.openai_search_vector_store(
vector_store_id=vector_store_id,
query=query,
filters=filters,
max_num_results=max_num_results,
ranking_options=ranking_options,
rewrite_query=rewrite_query,
search_mode=search_mode,
)
async def openai_attach_file_to_vector_store(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
return await self.routing_table.openai_attach_file_to_vector_store(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
chunking_strategy=chunking_strategy,
)
async def openai_list_files_in_vector_store(
self,
vector_store_id: str,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
filter: VectorStoreFileStatus | None = None,
) -> list[VectorStoreFileObject]:
logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}")
return await self.routing_table.openai_list_files_in_vector_store(
vector_store_id=vector_store_id,
limit=limit,
order=order,
after=after,
before=before,
filter=filter,
)
async def openai_retrieve_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}")
return await self.routing_table.openai_retrieve_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)
async def openai_retrieve_vector_store_file_contents(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
return await self.routing_table.openai_retrieve_vector_store_file_contents(
vector_store_id=vector_store_id,
file_id=file_id,
)
async def openai_update_vector_store_file(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any],
) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}")
return await self.routing_table.openai_update_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
)
async def openai_delete_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileDeleteResponse:
logger.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}")
return await self.routing_table.openai_delete_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)
async def health(self) -> dict[str, HealthResponse]:
health_statuses = {}
timeout = 1 # increasing the timeout to 1 second for health checks
for provider_id, impl in self.routing_table.impls_by_provider_id.items():
try:
# check if the provider has a health method
if not hasattr(impl, "health"):
continue
health = await asyncio.wait_for(impl.health(), timeout=timeout)
health_statuses[provider_id] = health
except TimeoutError:
health_statuses[provider_id] = HealthResponse(
status=HealthStatus.ERROR,
message=f"Health check timed out after {timeout} seconds",
)
except NotImplementedError:
health_statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
except Exception as e:
health_statuses[provider_id] = HealthResponse(
status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
)
return health_statuses

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack.core.datatypes import (
BenchmarkWithOwner,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
async def list_benchmarks(self) -> ListBenchmarksResponse:
return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))
async def get_benchmark(self, benchmark_id: str) -> Benchmark:
benchmark = await self.get_object_by_identifier("benchmark", benchmark_id)
if benchmark is None:
raise ValueError(f"Benchmark '{benchmark_id}' not found")
return benchmark
async def register_benchmark(
self,
benchmark_id: str,
dataset_id: str,
scoring_functions: list[str],
metadata: dict[str, Any] | None = None,
provider_benchmark_id: str | None = None,
provider_id: str | None = None,
) -> None:
if metadata is None:
metadata = {}
if provider_id is None:
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
if provider_benchmark_id is None:
provider_benchmark_id = benchmark_id
benchmark = BenchmarkWithOwner(
identifier=benchmark_id,
dataset_id=dataset_id,
scoring_functions=scoring_functions,
metadata=metadata,
provider_id=provider_id,
provider_resource_id=provider_benchmark_id,
)
await self.register_object(benchmark)

View file

@ -0,0 +1,266 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.common.errors import ModelNotFoundError
from llama_stack.apis.models import Model
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.core.access_control.datatypes import Action
from llama_stack.core.datatypes import (
AccessRule,
RoutableObject,
RoutableObjectWithProvider,
RoutedProtocol,
)
from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.core.store import DistributionRegistry
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, RoutingTable
logger = get_logger(name=__name__, category="core")
def get_impl_api(p: Any) -> Api:
return p.__provider_spec__.api
# TODO: this should return the registered object for all APIs
async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
api = get_impl_api(p)
assert obj.provider_id != "remote", "Remote provider should not be registered"
if api == Api.inference:
return await p.register_model(obj)
elif api == Api.safety:
return await p.register_shield(obj)
elif api == Api.vector_io:
return await p.register_vector_db(obj)
elif api == Api.datasetio:
return await p.register_dataset(obj)
elif api == Api.scoring:
return await p.register_scoring_function(obj)
elif api == Api.eval:
return await p.register_benchmark(obj)
elif api == Api.tool_runtime:
return await p.register_toolgroup(obj)
else:
raise ValueError(f"Unknown API {api} for registering object with provider")
async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
api = get_impl_api(p)
if api == Api.vector_io:
return await p.unregister_vector_db(obj.identifier)
elif api == Api.inference:
return await p.unregister_model(obj.identifier)
elif api == Api.datasetio:
return await p.unregister_dataset(obj.identifier)
elif api == Api.tool_runtime:
return await p.unregister_toolgroup(obj.identifier)
else:
raise ValueError(f"Unregister not supported for {api}")
Registry = dict[str, list[RoutableObjectWithProvider]]
class CommonRoutingTableImpl(RoutingTable):
def __init__(
self,
impls_by_provider_id: dict[str, RoutedProtocol],
dist_registry: DistributionRegistry,
policy: list[AccessRule],
) -> None:
self.impls_by_provider_id = impls_by_provider_id
self.dist_registry = dist_registry
self.policy = policy
async def initialize(self) -> None:
async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None:
for obj in objs:
if cls is None:
obj.provider_id = provider_id
else:
# Create a copy of the model data and explicitly set provider_id
model_data = obj.model_dump()
model_data["provider_id"] = provider_id
obj = cls(**model_data)
await self.dist_registry.register(obj)
# Register all objects from providers
for pid, p in self.impls_by_provider_id.items():
api = get_impl_api(p)
if api == Api.inference:
p.model_store = self
elif api == Api.safety:
p.shield_store = self
elif api == Api.vector_io:
p.vector_db_store = self
elif api == Api.datasetio:
p.dataset_store = self
elif api == Api.scoring:
p.scoring_function_store = self
scoring_functions = await p.list_scoring_functions()
await add_objects(scoring_functions, pid, ScoringFn)
elif api == Api.eval:
p.benchmark_store = self
elif api == Api.tool_runtime:
p.tool_store = self
async def shutdown(self) -> None:
for p in self.impls_by_provider_id.values():
await p.shutdown()
async def refresh(self) -> None:
pass
async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
from .benchmarks import BenchmarksRoutingTable
from .datasets import DatasetsRoutingTable
from .models import ModelsRoutingTable
from .scoring_functions import ScoringFunctionsRoutingTable
from .shields import ShieldsRoutingTable
from .toolgroups import ToolGroupsRoutingTable
from .vector_dbs import VectorDBsRoutingTable
def apiname_object():
if isinstance(self, ModelsRoutingTable):
return ("Inference", "model")
elif isinstance(self, ShieldsRoutingTable):
return ("Safety", "shield")
elif isinstance(self, VectorDBsRoutingTable):
return ("VectorIO", "vector_db")
elif isinstance(self, DatasetsRoutingTable):
return ("DatasetIO", "dataset")
elif isinstance(self, ScoringFunctionsRoutingTable):
return ("Scoring", "scoring_function")
elif isinstance(self, BenchmarksRoutingTable):
return ("Eval", "benchmark")
elif isinstance(self, ToolGroupsRoutingTable):
return ("ToolGroups", "tool_group")
else:
raise ValueError("Unknown routing table type")
apiname, objtype = apiname_object()
# Get objects from disk registry
obj = self.dist_registry.get_cached(objtype, routing_key)
if not obj:
provider_ids = list(self.impls_by_provider_id.keys())
if len(provider_ids) > 1:
provider_ids_str = f"any of the providers: {', '.join(provider_ids)}"
else:
provider_ids_str = f"provider: `{provider_ids[0]}`"
raise ValueError(
f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}."
)
if not provider_id or provider_id == obj.provider_id:
return self.impls_by_provider_id[obj.provider_id]
raise ValueError(f"Provider not found for `{routing_key}`")
async def get_object_by_identifier(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
# Get from disk registry
obj = await self.dist_registry.get(type, identifier)
if not obj:
return None
# Check if user has permission to access this object
if not is_action_allowed(self.policy, "read", obj, get_authenticated_user()):
logger.debug(f"Access denied to {type} '{identifier}'")
return None
return obj
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
user = get_authenticated_user()
if not is_action_allowed(self.policy, "delete", obj, user):
raise AccessDeniedError("delete", obj, user)
await self.dist_registry.delete(obj.type, obj.identifier)
await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
# if provider_id is not specified, pick an arbitrary one from existing entries
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
obj.provider_id = list(self.impls_by_provider_id.keys())[0]
if obj.provider_id not in self.impls_by_provider_id:
raise ValueError(f"Provider `{obj.provider_id}` not found")
p = self.impls_by_provider_id[obj.provider_id]
# If object supports access control but no attributes set, use creator's attributes
creator = get_authenticated_user()
if not is_action_allowed(self.policy, "create", obj, creator):
raise AccessDeniedError("create", obj, creator)
if creator:
obj.owner = creator
logger.info(f"Setting owner for {obj.type} '{obj.identifier}' to {obj.owner.principal}")
registered_obj = await register_object_with_provider(obj, p)
# TODO: This needs to be fixed for all APIs once they return the registered object
if obj.type == ResourceType.model.value:
await self.dist_registry.register(registered_obj)
return registered_obj
else:
await self.dist_registry.register(obj)
return obj
async def assert_action_allowed(
self,
action: Action,
type: str,
identifier: str,
) -> None:
"""Fetch a registered object by type/identifier and enforce the given action permission."""
obj = await self.get_object_by_identifier(type, identifier)
if obj is None:
raise ValueError(f"{type.capitalize()} '{identifier}' not found")
user = get_authenticated_user()
if not is_action_allowed(self.policy, action, obj, user):
raise AccessDeniedError(action, obj, user)
async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
objs = await self.dist_registry.get_all()
filtered_objs = [obj for obj in objs if obj.type == type]
# Apply attribute-based access control filtering
if filtered_objs:
filtered_objs = [
obj for obj in filtered_objs if is_action_allowed(self.policy, "read", obj, get_authenticated_user())
]
return filtered_objs
async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> Model:
# first try to get the model by identifier
# this works if model_id is an alias or is of the form provider_id/provider_model_id
model = await routing_table.get_object_by_identifier("model", model_id)
if model is not None:
return model
logger.warning(
f"WARNING: model identifier '{model_id}' not found in routing table. Falling back to "
"searching in all providers. This is only for backwards compatibility and will stop working "
"soon. Migrate your calls to use fully scoped `provider_id/model_id` names."
)
# if not found, this means model_id is an unscoped provider_model_id, we need
# to iterate (given a lack of an efficient index on the KVStore)
models = await routing_table.get_all_with_type("model")
matching_models = [m for m in models if m.provider_resource_id == model_id]
if len(matching_models) == 0:
raise ModelNotFoundError(model_id)
if len(matching_models) > 1:
raise ValueError(f"Multiple providers found for '{model_id}': {[m.provider_id for m in matching_models]}")
return matching_models[0]

View file

@ -0,0 +1,91 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from typing import Any
from llama_stack.apis.common.errors import DatasetNotFoundError
from llama_stack.apis.datasets import (
Dataset,
DatasetPurpose,
Datasets,
DatasetType,
DataSource,
ListDatasetsResponse,
RowsDataSource,
URIDataSource,
)
from llama_stack.apis.resource import ResourceType
from llama_stack.core.datatypes import (
DatasetWithOwner,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> ListDatasetsResponse:
return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
async def get_dataset(self, dataset_id: str) -> Dataset:
dataset = await self.get_object_by_identifier("dataset", dataset_id)
if dataset is None:
raise DatasetNotFoundError(dataset_id)
return dataset
async def register_dataset(
self,
purpose: DatasetPurpose,
source: DataSource,
metadata: dict[str, Any] | None = None,
dataset_id: str | None = None,
) -> Dataset:
if isinstance(source, dict):
if source["type"] == "uri":
source = URIDataSource.parse_obj(source)
elif source["type"] == "rows":
source = RowsDataSource.parse_obj(source)
if not dataset_id:
dataset_id = f"dataset-{str(uuid.uuid4())}"
provider_dataset_id = dataset_id
# infer provider from source
if metadata and metadata.get("provider_id"):
provider_id = metadata.get("provider_id") # pass through from nvidia datasetio
elif source.type == DatasetType.rows.value:
provider_id = "localfs"
elif source.type == DatasetType.uri.value:
# infer provider from uri
if source.uri.startswith("huggingface"):
provider_id = "huggingface"
else:
provider_id = "localfs"
else:
raise ValueError(f"Unknown data source type: {source.type}")
if metadata is None:
metadata = {}
dataset = DatasetWithOwner(
identifier=dataset_id,
provider_resource_id=provider_dataset_id,
provider_id=provider_id,
purpose=purpose,
source=source,
metadata=metadata,
)
await self.register_object(dataset)
return dataset
async def unregister_dataset(self, dataset_id: str) -> None:
dataset = await self.get_dataset(dataset_id)
await self.unregister_object(dataset)

View file

@ -0,0 +1,156 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time
from typing import Any
from llama_stack.apis.common.errors import ModelNotFoundError
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.core.datatypes import (
ModelWithOwner,
RegistryEntrySource,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core")
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
listed_providers: set[str] = set()
async def refresh(self) -> None:
for provider_id, provider in self.impls_by_provider_id.items():
refresh = await provider.should_refresh_models()
refresh = refresh or provider_id not in self.listed_providers
if not refresh:
continue
try:
models = await provider.list_models()
except Exception as e:
logger.exception(f"Model refresh failed for provider {provider_id}: {e}")
continue
self.listed_providers.add(provider_id)
if models is None:
continue
await self.update_registered_models(provider_id, models)
async def list_models(self) -> ListModelsResponse:
return ListModelsResponse(data=await self.get_all_with_type("model"))
async def openai_list_models(self) -> OpenAIListModelsResponse:
models = await self.get_all_with_type("model")
openai_models = [
OpenAIModel(
id=model.identifier,
object="model",
created=int(time.time()),
owned_by="llama_stack",
)
for model in models
]
return OpenAIListModelsResponse(data=openai_models)
async def get_model(self, model_id: str) -> Model:
return await lookup_model(self, model_id)
async def get_provider_impl(self, model_id: str) -> Any:
model = await lookup_model(self, model_id)
return self.impls_by_provider_id[model.provider_id]
async def register_model(
self,
model_id: str,
provider_model_id: str | None = None,
provider_id: str | None = None,
metadata: dict[str, Any] | None = None,
model_type: ModelType | None = None,
) -> Model:
if provider_id is None:
# If provider_id not specified, use the only provider if it supports this model
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
f"Please specify a provider_id for model {model_id} since multiple providers are available: {self.impls_by_provider_id.keys()}.\n\n"
"Use the provider_id as a prefix to disambiguate, e.g. 'provider_id/model_id'."
)
provider_model_id = provider_model_id or model_id
metadata = metadata or {}
model_type = model_type or ModelType.llm
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
raise ValueError("Embedding model must have an embedding dimension in its metadata")
# an identifier different than provider_model_id implies it is an alias, so that
# becomes the globally unique identifier. otherwise provider_model_ids can conflict,
# so as a general rule we must use the provider_id to disambiguate.
if model_id != provider_model_id:
identifier = model_id
else:
identifier = f"{provider_id}/{provider_model_id}"
model = ModelWithOwner(
identifier=identifier,
provider_resource_id=provider_model_id,
provider_id=provider_id,
metadata=metadata,
model_type=model_type,
source=RegistryEntrySource.via_register_api,
)
registered_model = await self.register_object(model)
return registered_model
async def unregister_model(self, model_id: str) -> None:
existing_model = await self.get_model(model_id)
if existing_model is None:
raise ModelNotFoundError(model_id)
await self.unregister_object(existing_model)
async def update_registered_models(
self,
provider_id: str,
models: list[Model],
) -> None:
existing_models = await self.get_all_with_type("model")
# we may have an alias for the model registered by the user (or during initialization
# from run.yaml) that we need to keep track of
model_ids = {}
for model in existing_models:
if model.provider_id != provider_id:
continue
if model.source == RegistryEntrySource.via_register_api:
model_ids[model.provider_resource_id] = model.identifier
continue
logger.debug(f"unregistering model {model.identifier}")
await self.unregister_object(model)
for model in models:
if model.provider_resource_id in model_ids:
# avoid overwriting a non-provider-registered model entry
continue
if model.identifier == model.provider_resource_id:
model.identifier = f"{provider_id}/{model.provider_resource_id}"
logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
await self.register_object(
ModelWithOwner(
identifier=model.identifier,
provider_resource_id=model.provider_resource_id,
provider_id=provider_id,
metadata=model.metadata,
model_type=model.model_type,
source=RegistryEntrySource.listed_from_provider,
)
)

View file

@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
ListScoringFunctionsResponse,
ScoringFn,
ScoringFnParams,
ScoringFunctions,
)
from llama_stack.core.datatypes import (
ScoringFnWithOwner,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
if scoring_fn is None:
raise ValueError(f"Scoring function '{scoring_fn_id}' not found")
return scoring_fn
async def register_scoring_function(
self,
scoring_fn_id: str,
description: str,
return_type: ParamType,
provider_scoring_fn_id: str | None = None,
provider_id: str | None = None,
params: ScoringFnParams | None = None,
) -> None:
if provider_scoring_fn_id is None:
provider_scoring_fn_id = scoring_fn_id
if provider_id is None:
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
scoring_fn = ScoringFnWithOwner(
identifier=scoring_fn_id,
description=description,
return_type=return_type,
provider_resource_id=provider_scoring_fn_id,
provider_id=provider_id,
params=params,
)
scoring_fn.provider_id = provider_id
await self.register_object(scoring_fn)

View file

@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
from llama_stack.core.datatypes import (
ShieldWithOwner,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> ListShieldsResponse:
return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
async def get_shield(self, identifier: str) -> Shield:
shield = await self.get_object_by_identifier("shield", identifier)
if shield is None:
raise ValueError(f"Shield '{identifier}' not found")
return shield
async def register_shield(
self,
shield_id: str,
provider_shield_id: str | None = None,
provider_id: str | None = None,
params: dict[str, Any] | None = None,
) -> Shield:
if provider_shield_id is None:
provider_shield_id = shield_id
if provider_id is None:
# If provider_id not specified, use the only provider if it supports this shield type
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
if params is None:
params = {}
shield = ShieldWithOwner(
identifier=shield_id,
provider_resource_id=provider_shield_id,
provider_id=provider_id,
params=params,
)
await self.register_object(shield)
return shield

View file

@ -0,0 +1,132 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
from llama_stack.core.datatypes import ToolGroupWithOwner
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None:
# handle the funny case like "builtin::rag/knowledge_search"
parts = toolgroup_name_with_maybe_tool_name.split("/")
if len(parts) == 2:
return parts[0]
else:
return None
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
toolgroups_to_tools: dict[str, list[Tool]] = {}
tool_to_toolgroup: dict[str, str] = {}
# overridden
async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
# we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id
# TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while?
toolgroup_id = parse_toolgroup_from_toolgroup_name_pair(routing_key)
if toolgroup_id:
routing_key = toolgroup_id
if routing_key in self.tool_to_toolgroup:
routing_key = self.tool_to_toolgroup[routing_key]
return await super().get_provider_impl(routing_key, provider_id)
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
if toolgroup_id:
if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id):
toolgroup_id = group_id
toolgroups = [await self.get_tool_group(toolgroup_id)]
else:
toolgroups = await self.get_all_with_type("tool_group")
all_tools = []
for toolgroup in toolgroups:
if toolgroup.identifier not in self.toolgroups_to_tools:
await self._index_tools(toolgroup)
all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier])
return ListToolsResponse(data=all_tools)
async def _index_tools(self, toolgroup: ToolGroup):
provider_impl = await super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)
# TODO: kill this Tool vs ToolDef distinction
tooldefs = tooldefs_response.data
tools = []
for t in tooldefs:
tools.append(
Tool(
identifier=t.name,
toolgroup_id=toolgroup.identifier,
description=t.description or "",
parameters=t.parameters or [],
metadata=t.metadata,
provider_id=toolgroup.provider_id,
)
)
self.toolgroups_to_tools[toolgroup.identifier] = tools
for tool in tools:
self.tool_to_toolgroup[tool.identifier] = toolgroup.identifier
async def list_tool_groups(self) -> ListToolGroupsResponse:
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id)
if tool_group is None:
raise ValueError(f"Tool group '{toolgroup_id}' not found")
return tool_group
async def get_tool(self, tool_name: str) -> Tool:
if tool_name in self.tool_to_toolgroup:
toolgroup_id = self.tool_to_toolgroup[tool_name]
tools = self.toolgroups_to_tools[toolgroup_id]
for tool in tools:
if tool.identifier == tool_name:
return tool
raise ValueError(f"Tool '{tool_name}' not found")
async def register_tool_group(
self,
toolgroup_id: str,
provider_id: str,
mcp_endpoint: URL | None = None,
args: dict[str, Any] | None = None,
) -> None:
toolgroup = ToolGroupWithOwner(
identifier=toolgroup_id,
provider_id=provider_id,
provider_resource_id=toolgroup_id,
mcp_endpoint=mcp_endpoint,
args=args,
)
await self.register_object(toolgroup)
# ideally, indexing of the tools should not be necessary because anyone using
# the tools should first list the tools and then use them. but there are assumptions
# baked in some of the code and tests right now.
if not toolgroup.mcp_endpoint:
await self._index_tools(toolgroup)
return toolgroup
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
tool_group = await self.get_tool_group(toolgroup_id)
if tool_group is None:
raise ValueError(f"Tool group {toolgroup_id} not found")
await self.unregister_object(tool_group)
async def shutdown(self) -> None:
pass

View file

@ -0,0 +1,231 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import TypeAdapter
from llama_stack.apis.common.errors import ModelNotFoundError, VectorStoreNotFoundError
from llama_stack.apis.models import ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.apis.vector_io.vector_io import (
SearchRankingOptions,
VectorStoreChunkingStrategy,
VectorStoreDeleteResponse,
VectorStoreFileContentsResponse,
VectorStoreFileDeleteResponse,
VectorStoreFileObject,
VectorStoreFileStatus,
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.core.datatypes import (
VectorDBWithOwner,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core")
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
async def list_vector_dbs(self) -> ListVectorDBsResponse:
return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))
async def get_vector_db(self, vector_db_id: str) -> VectorDB:
vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
if vector_db is None:
raise VectorStoreNotFoundError(vector_db_id)
return vector_db
async def register_vector_db(
self,
vector_db_id: str,
embedding_model: str,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
vector_db_name: str | None = None,
) -> VectorDB:
provider_vector_db_id = provider_vector_db_id or vector_db_id
if provider_id is None:
if len(self.impls_by_provider_id) > 0:
provider_id = list(self.impls_by_provider_id.keys())[0]
if len(self.impls_by_provider_id) > 1:
logger.warning(
f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
)
else:
raise ValueError("No provider available. Please configure a vector_io provider.")
model = await lookup_model(self, embedding_model)
if model is None:
raise ModelNotFoundError(embedding_model)
if model.model_type != ModelType.embedding:
raise ValueError(f"Model {embedding_model} is not an embedding model")
if "embedding_dimension" not in model.metadata:
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
vector_db_data = {
"identifier": vector_db_id,
"type": ResourceType.vector_db.value,
"provider_id": provider_id,
"provider_resource_id": provider_vector_db_id,
"embedding_model": embedding_model,
"embedding_dimension": model.metadata["embedding_dimension"],
"vector_db_name": vector_db_name,
}
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
await self.register_object(vector_db)
return vector_db
async def unregister_vector_db(self, vector_db_id: str) -> None:
existing_vector_db = await self.get_vector_db(vector_db_id)
if existing_vector_db is None:
raise VectorStoreNotFoundError(vector_db_id)
await self.unregister_object(existing_vector_db)
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store(vector_store_id)
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store(
vector_store_id=vector_store_id,
name=name,
expires_after=expires_after,
metadata=metadata,
)
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
result = await provider.openai_delete_vector_store(vector_store_id)
await self.unregister_vector_db(vector_store_id)
return result
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
search_mode: str | None = "vector",
) -> VectorStoreSearchResponsePage:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_search_vector_store(
vector_store_id=vector_store_id,
query=query,
filters=filters,
max_num_results=max_num_results,
ranking_options=ranking_options,
rewrite_query=rewrite_query,
search_mode=search_mode,
)
async def openai_attach_file_to_vector_store(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_attach_file_to_vector_store(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
chunking_strategy=chunking_strategy,
)
async def openai_list_files_in_vector_store(
self,
vector_store_id: str,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
filter: VectorStoreFileStatus | None = None,
) -> list[VectorStoreFileObject]:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_list_files_in_vector_store(
vector_store_id=vector_store_id,
limit=limit,
order=order,
after=after,
before=before,
filter=filter,
)
async def openai_retrieve_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)
async def openai_retrieve_vector_store_file_contents(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_contents(
vector_store_id=vector_store_id,
file_id=file_id,
)
async def openai_update_vector_store_file(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any],
) -> VectorStoreFileObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
)
async def openai_delete_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileDeleteResponse:
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_delete_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,175 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import httpx
from aiohttp import hdrs
from llama_stack.core.datatypes import AuthenticationConfig, User
from llama_stack.core.request_headers import user_from_scope
from llama_stack.core.server.auth_providers import create_auth_provider
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
class AuthenticationMiddleware:
"""Middleware that authenticates requests using configured authentication provider.
This middleware:
1. Extracts the Bearer token from the Authorization header
2. Uses the configured auth provider to validate the token
3. Extracts user attributes from the provider's response
4. Makes these attributes available to the route handlers for access control
The middleware supports multiple authentication providers through the AuthProvider interface:
- Kubernetes: Validates tokens against the Kubernetes API server
- Custom: Validates tokens against a custom endpoint
Authentication Request Format for Custom Auth Provider:
```json
{
"api_key": "the-api-key-extracted-from-auth-header",
"request": {
"path": "/models/list",
"headers": {
"content-type": "application/json",
"user-agent": "..."
// All headers except Authorization
},
"params": {
"limit": ["100"],
"offset": ["0"]
// Query parameters as key -> list of values
}
}
}
```
Expected Auth Endpoint Response Format:
```json
{
"access_attributes": { // Structured attribute format
"roles": ["admin", "user"],
"teams": ["ml-team", "nlp-team"],
"projects": ["llama-3", "project-x"],
"namespaces": ["research"]
},
"message": "Optional message about auth result"
}
```
Token Validation:
Each provider implements its own token validation logic:
- Kubernetes: Uses TokenReview API to validate service account tokens
- Custom: Sends token to custom endpoint for validation
Attribute-Based Access Control:
The attributes returned by the auth provider are used to determine which
resources the user can access. Resources can specify required attributes
using the access_attributes field. For a user to access a resource:
1. All attribute categories specified in the resource must be present in the user's attributes
2. For each category, the user must have at least one matching value
If the auth provider doesn't return any attributes, the user will only be able to
access resources that don't have access_attributes defined.
"""
def __init__(self, app, auth_config: AuthenticationConfig, impls):
self.app = app
self.impls = impls
self.auth_provider = create_auth_provider(auth_config)
async def __call__(self, scope, receive, send):
if scope["type"] == "http":
# First, handle authentication
headers = dict(scope.get("headers", []))
auth_header = headers.get(b"authorization", b"").decode()
if not auth_header:
error_msg = self.auth_provider.get_auth_error_message(scope)
return await self._send_auth_error(send, error_msg)
if not auth_header.startswith("Bearer "):
return await self._send_auth_error(send, "Invalid Authorization header format")
token = auth_header.split("Bearer ", 1)[1]
# Validate token and get access attributes
try:
validation_result = await self.auth_provider.validate_token(token, scope)
except httpx.TimeoutException:
logger.exception("Authentication request timed out")
return await self._send_auth_error(send, "Authentication service timeout")
except ValueError as e:
logger.exception("Error during authentication")
return await self._send_auth_error(send, str(e))
except Exception:
logger.exception("Error during authentication")
return await self._send_auth_error(send, "Authentication service error")
# Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
# can identify the requester and enforce per-client rate limits.
scope["authenticated_client_id"] = token
# Store attributes in request scope
scope["principal"] = validation_result.principal
if validation_result.attributes:
scope["user_attributes"] = validation_result.attributes
logger.debug(
f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes"
)
# Scope-based API access control
path = scope.get("path", "")
method = scope.get("method", hdrs.METH_GET)
if not hasattr(self, "route_impls"):
self.route_impls = initialize_route_impls(self.impls)
try:
_, _, _, webmethod = find_matching_route(method, path, self.route_impls)
except ValueError:
# If no matching endpoint is found, pass through to FastAPI
return await self.app(scope, receive, send)
if webmethod.required_scope:
user = user_from_scope(scope)
if not _has_required_scope(webmethod.required_scope, user):
return await self._send_auth_error(
send,
f"Access denied: user does not have required scope: {webmethod.required_scope}",
status=403,
)
return await self.app(scope, receive, send)
async def _send_auth_error(self, send, message, status=401):
await send(
{
"type": "http.response.start",
"status": status,
"headers": [[b"content-type", b"application/json"]],
}
)
error_key = "message" if status == 401 else "detail"
error_msg = json.dumps({"error": {error_key: message}}).encode()
await send({"type": "http.response.body", "body": error_msg})
def _has_required_scope(required_scope: str, user: User | None) -> bool:
# if no user, assume auth is not enabled
if not user:
return True
if not user.attributes:
return False
user_scopes = user.attributes.get("scopes", [])
return required_scope in user_scopes

View file

@ -0,0 +1,388 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import ssl
import time
from abc import ABC, abstractmethod
from asyncio import Lock
from urllib.parse import parse_qs, urlparse
import httpx
from jose import jwt
from pydantic import BaseModel, Field
from llama_stack.core.datatypes import (
AuthenticationConfig,
CustomAuthConfig,
GitHubTokenAuthConfig,
OAuth2TokenAuthConfig,
User,
)
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
class AuthResponse(BaseModel):
"""The format of the authentication response from the auth endpoint."""
principal: str
# further attributes that may be used for access control decisions
attributes: dict[str, list[str]] | None = None
message: str | None = Field(
default=None, description="Optional message providing additional context about the authentication result."
)
class AuthRequestContext(BaseModel):
path: str = Field(description="The path of the request being authenticated")
headers: dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
params: dict[str, list[str]] = Field(default_factory=dict, description="Query parameters from the original request")
class AuthRequest(BaseModel):
api_key: str = Field(description="The API key extracted from the Authorization header")
request: AuthRequestContext = Field(description="Context information about the request being authenticated")
class AuthProvider(ABC):
"""Abstract base class for authentication providers."""
@abstractmethod
async def validate_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a token and return access attributes."""
pass
@abstractmethod
async def close(self):
"""Clean up any resources."""
pass
def get_auth_error_message(self, scope: dict | None = None) -> str:
"""Return provider-specific authentication error message."""
return "Authentication required"
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]:
attributes: dict[str, list[str]] = {}
for claim_key, attribute_key in mapping.items():
if claim_key not in claims:
continue
claim = claims[claim_key]
if isinstance(claim, list):
values = claim
else:
values = claim.split()
if attribute_key in attributes:
attributes[attribute_key].extend(values)
else:
attributes[attribute_key] = values
return attributes
class OAuth2TokenAuthProvider(AuthProvider):
"""
JWT token authentication provider that validates a JWT token and extracts access attributes.
This should be the standard authentication provider for most use cases.
"""
def __init__(self, config: OAuth2TokenAuthConfig):
self.config = config
self._jwks_at: float = 0.0
self._jwks: dict[str, str] = {}
self._jwks_lock = Lock()
async def validate_token(self, token: str, scope: dict | None = None) -> User:
if self.config.jwks:
return await self.validate_jwt_token(token, scope)
if self.config.introspection:
return await self.introspect_token(token, scope)
raise ValueError("One of jwks or introspection must be configured")
async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a token using the JWT token."""
await self._refresh_jwks()
try:
header = jwt.get_unverified_header(token)
kid = header["kid"]
if kid not in self._jwks:
raise ValueError(f"Unknown key ID: {kid}")
key_data = self._jwks[kid]
algorithm = header.get("alg", "RS256")
claims = jwt.decode(
token,
key_data,
algorithms=[algorithm],
audience=self.config.audience,
issuer=self.config.issuer,
)
except Exception as exc:
raise ValueError("Invalid JWT token") from exc
# There are other standard claims, the most relevant of which is `scope`.
# We should incorporate these into the access attributes.
principal = claims["sub"]
access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
return User(
principal=principal,
attributes=access_attributes,
)
async def introspect_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a token using token introspection as defined by RFC 7662."""
form = {
"token": token,
}
if self.config.introspection is None:
raise ValueError("Introspection is not configured")
if self.config.introspection.send_secret_in_body:
form["client_id"] = self.config.introspection.client_id
form["client_secret"] = self.config.introspection.client_secret
auth = None
else:
auth = (self.config.introspection.client_id, self.config.introspection.client_secret)
ssl_ctxt = None
if self.config.tls_cafile:
ssl_ctxt = ssl.create_default_context(cafile=self.config.tls_cafile.as_posix())
try:
async with httpx.AsyncClient(verify=ssl_ctxt) as client:
response = await client.post(
self.config.introspection.url,
data=form,
auth=auth,
timeout=10.0, # Add a reasonable timeout
)
if response.status_code != 200:
logger.warning(f"Token introspection failed with status code: {response.status_code}")
raise ValueError(f"Token introspection failed: {response.status_code}")
fields = response.json()
if not fields["active"]:
raise ValueError("Token not active")
principal = fields["sub"] or fields["username"]
access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping)
return User(
principal=principal,
attributes=access_attributes,
)
except httpx.TimeoutException:
logger.exception("Token introspection request timed out")
raise
except ValueError:
# Re-raise ValueError exceptions to preserve their message
raise
except Exception as e:
logger.exception("Error during token introspection")
raise ValueError("Token introspection error") from e
async def close(self):
pass
def get_auth_error_message(self, scope: dict | None = None) -> str:
"""Return OAuth2-specific authentication error message."""
if self.config.issuer:
return f"Authentication required. Please provide a valid OAuth2 Bearer token from {self.config.issuer}"
elif self.config.introspection:
# Extract domain from introspection URL for a cleaner message
domain = urlparse(self.config.introspection.url).netloc
return f"Authentication required. Please provide a valid OAuth2 Bearer token validated by {domain}"
else:
return "Authentication required. Please provide a valid OAuth2 Bearer token in the Authorization header"
async def _refresh_jwks(self) -> None:
"""
Refresh the JWKS cache.
This is a simple cache that expires after a certain amount of time (defined by `key_recheck_period`).
If the cache is expired, we refresh the JWKS from the JWKS URI.
Notes: for Kubernetes which doesn't fully implement the OIDC protocol:
* It doesn't have user authentication flows
* It doesn't have refresh tokens
"""
async with self._jwks_lock:
if self.config.jwks is None:
raise ValueError("JWKS is not configured")
if time.time() - self._jwks_at > self.config.jwks.key_recheck_period:
headers = {}
if self.config.jwks.token:
headers["Authorization"] = f"Bearer {self.config.jwks.token}"
verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls
async with httpx.AsyncClient(verify=verify) as client:
res = await client.get(self.config.jwks.uri, timeout=5, headers=headers)
res.raise_for_status()
jwks_data = res.json()["keys"]
updated = {}
for k in jwks_data:
kid = k["kid"]
# Store the entire key object as it may be needed for different algorithms
updated[kid] = k
self._jwks = updated
self._jwks_at = time.time()
class CustomAuthProvider(AuthProvider):
"""Custom authentication provider that uses an external endpoint."""
def __init__(self, config: CustomAuthConfig):
self.config = config
self._client = None
async def validate_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a token using the custom authentication endpoint."""
if scope is None:
scope = {}
headers = dict(scope.get("headers", []))
path = scope.get("path", "")
request_headers = {k.decode(): v.decode() for k, v in headers.items()}
# Remove sensitive headers
if "authorization" in request_headers:
del request_headers["authorization"]
query_string = scope.get("query_string", b"").decode()
params = parse_qs(query_string)
# Build the auth request model
auth_request = AuthRequest(
api_key=token,
request=AuthRequestContext(
path=path,
headers=request_headers,
params=params,
),
)
# Validate with authentication endpoint
try:
async with httpx.AsyncClient() as client:
response = await client.post(
self.config.endpoint,
json=auth_request.model_dump(),
timeout=10.0, # Add a reasonable timeout
)
if response.status_code != 200:
logger.warning(f"Authentication failed with status code: {response.status_code}")
raise ValueError(f"Authentication failed: {response.status_code}")
# Parse and validate the auth response
try:
response_data = response.json()
auth_response = AuthResponse(**response_data)
return User(principal=auth_response.principal, attributes=auth_response.attributes)
except Exception as e:
logger.exception("Error parsing authentication response")
raise ValueError("Invalid authentication response format") from e
except httpx.TimeoutException:
logger.exception("Authentication request timed out")
raise
except ValueError:
# Re-raise ValueError exceptions to preserve their message
raise
except Exception as e:
logger.exception("Error during authentication")
raise ValueError("Authentication service error") from e
async def close(self):
"""Close the HTTP client."""
if self._client:
await self._client.aclose()
self._client = None
def get_auth_error_message(self, scope: dict | None = None) -> str:
"""Return custom auth provider-specific authentication error message."""
domain = urlparse(self.config.endpoint).netloc
if domain:
return f"Authentication required. Please provide your API key as a Bearer token (validated by {domain})"
else:
return "Authentication required. Please provide your API key as a Bearer token in the Authorization header"
class GitHubTokenAuthProvider(AuthProvider):
"""
GitHub token authentication provider that validates GitHub access tokens directly.
This provider accepts GitHub personal access tokens or OAuth tokens and verifies
them against the GitHub API to get user information.
"""
def __init__(self, config: GitHubTokenAuthConfig):
self.config = config
async def validate_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a GitHub token by calling the GitHub API.
This validates tokens issued by GitHub (personal access tokens or OAuth tokens).
"""
try:
user_info = await _get_github_user_info(token, self.config.github_api_base_url)
except httpx.HTTPStatusError as e:
logger.warning(f"GitHub token validation failed: {e}")
raise ValueError("GitHub token validation failed. Please check your token and try again.") from e
principal = user_info["user"]["login"]
github_data = {
"login": user_info["user"]["login"],
"id": str(user_info["user"]["id"]),
"organizations": user_info.get("organizations", []),
}
access_attributes = get_attributes_from_claims(github_data, self.config.claims_mapping)
return User(
principal=principal,
attributes=access_attributes,
)
async def close(self):
"""Clean up any resources."""
pass
def get_auth_error_message(self, scope: dict | None = None) -> str:
"""Return GitHub-specific authentication error message."""
return "Authentication required. Please provide a valid GitHub access token (https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens) in the Authorization header (Bearer <token>)"
async def _get_github_user_info(access_token: str, github_api_base_url: str) -> dict:
"""Fetch user info and organizations from GitHub API."""
headers = {
"Authorization": f"Bearer {access_token}",
"Accept": "application/vnd.github.v3+json",
"User-Agent": "llama-stack",
}
async with httpx.AsyncClient() as client:
user_response = await client.get(f"{github_api_base_url}/user", headers=headers, timeout=10.0)
user_response.raise_for_status()
user_data = user_response.json()
return {
"user": user_data,
}
def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
"""Factory function to create the appropriate auth provider."""
provider_config = config.provider_config
if isinstance(provider_config, CustomAuthConfig):
return CustomAuthProvider(provider_config)
elif isinstance(provider_config, OAuth2TokenAuthConfig):
return OAuth2TokenAuthProvider(provider_config)
elif isinstance(provider_config, GitHubTokenAuthConfig):
return GitHubTokenAuthProvider(provider_config)
else:
raise ValueError(f"Unknown authentication provider config type: {type(provider_config)}")

View file

@ -0,0 +1,110 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import time
from datetime import UTC, datetime, timedelta
from starlette.types import ASGIApp, Receive, Scope, Send
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl
logger = get_logger(name=__name__, category="quota")
class QuotaMiddleware:
"""
ASGI middleware that enforces separate quotas for authenticated and anonymous clients
within a configurable time window.
- For authenticated requests, it reads the client ID from the
`Authorization: Bearer <client_id>` header.
- For anonymous requests, it falls back to the IP address of the client.
Requests are counted in a KV store (e.g., SQLite), and HTTP 429 is returned
once a client exceeds its quota.
"""
def __init__(
self,
app: ASGIApp,
kv_config: KVStoreConfig,
anonymous_max_requests: int,
authenticated_max_requests: int,
window_seconds: int = 86400,
):
self.app = app
self.kv_config = kv_config
self.kv: KVStore | None = None
self.anonymous_max_requests = anonymous_max_requests
self.authenticated_max_requests = authenticated_max_requests
self.window_seconds = window_seconds
if isinstance(self.kv_config, SqliteKVStoreConfig):
logger.warning(
"QuotaMiddleware: Using SQLite backend. Expiry/TTL is not enforced; cleanup is manual. "
f"window_seconds={self.window_seconds}"
)
async def _get_kv(self) -> KVStore:
if self.kv is None:
self.kv = await kvstore_impl(self.kv_config)
return self.kv
async def __call__(self, scope: Scope, receive: Receive, send: Send):
if scope["type"] == "http":
# pick key & limit based on auth
auth_id = scope.get("authenticated_client_id")
if auth_id:
key_id = auth_id
limit = self.authenticated_max_requests
else:
# fallback to IP
client = scope.get("client")
key_id = client[0] if client else "anonymous"
limit = self.anonymous_max_requests
current_window = int(time.time() // self.window_seconds)
key = f"quota:{key_id}:{current_window}"
try:
kv = await self._get_kv()
prev = await kv.get(key) or "0"
count = int(prev) + 1
if int(prev) == 0:
# Set with expiration datetime when it is the first request in the window.
expiration = datetime.now(UTC) + timedelta(seconds=self.window_seconds)
await kv.set(key, str(count), expiration=expiration)
else:
await kv.set(key, str(count))
except Exception:
logger.exception("Failed to access KV store for quota")
return await self._send_error(send, 500, "Quota service error")
if count > limit:
logger.warning(
"Quota exceeded for client %s: %d/%d",
key_id,
count,
limit,
)
return await self._send_error(send, 429, "Quota exceeded")
return await self.app(scope, receive, send)
async def _send_error(self, send: Send, status: int, message: str):
await send(
{
"type": "http.response.start",
"status": status,
"headers": [[b"content-type", b"application/json"]],
}
)
body = json.dumps({"error": {"message": message}}).encode()
await send({"type": "http.response.body", "body": body})

View file

@ -0,0 +1,141 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import inspect
import re
from collections.abc import Callable
from typing import Any
from aiohttp import hdrs
from starlette.routing import Route
from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
from llama_stack.core.resolver import api_protocol_map
from llama_stack.schema_utils import WebMethod
EndpointFunc = Callable[..., Any]
PathParams = dict[str, str]
RouteInfo = tuple[EndpointFunc, str, WebMethod]
PathImpl = dict[str, RouteInfo]
RouteImpls = dict[str, PathImpl]
RouteMatch = tuple[EndpointFunc, PathParams, str, WebMethod]
def toolgroup_protocol_map():
return {
SpecialToolGroup.rag_tool: RAGToolRuntime,
}
def get_all_api_routes(
external_apis: dict[Api, ExternalApiSpec] | None = None,
) -> dict[Api, list[tuple[Route, WebMethod]]]:
apis = {}
protocols = api_protocol_map(external_apis)
toolgroup_protocols = toolgroup_protocol_map()
for api, protocol in protocols.items():
routes = []
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
# HACK ALERT
if api == Api.tool_runtime:
for tool_group in SpecialToolGroup:
sub_protocol = toolgroup_protocols[tool_group]
sub_protocol_methods = inspect.getmembers(sub_protocol, predicate=inspect.isfunction)
for name, method in sub_protocol_methods:
if not hasattr(method, "__webmethod__"):
continue
protocol_methods.append((f"{tool_group.value}.{name}", method))
for name, method in protocol_methods:
if not hasattr(method, "__webmethod__"):
continue
# The __webmethod__ attribute is dynamically added by the @webmethod decorator
# mypy doesn't know about this dynamic attribute, so we ignore the attr-defined error
webmethod = method.__webmethod__ # type: ignore[attr-defined]
path = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
if webmethod.method == hdrs.METH_GET:
http_method = hdrs.METH_GET
elif webmethod.method == hdrs.METH_DELETE:
http_method = hdrs.METH_DELETE
else:
http_method = hdrs.METH_POST
routes.append(
(Route(path=path, methods=[http_method], name=name, endpoint=None), webmethod)
) # setting endpoint to None since don't use a Router object
apis[api] = routes
return apis
def initialize_route_impls(impls, external_apis: dict[Api, ExternalApiSpec] | None = None) -> RouteImpls:
api_to_routes = get_all_api_routes(external_apis)
route_impls: RouteImpls = {}
def _convert_path_to_regex(path: str) -> str:
# Convert {param} to named capture groups
# handle {param:path} as well which allows for forward slashes in the param value
pattern = re.sub(
r"{(\w+)(?::path)?}",
lambda m: f"(?P<{m.group(1)}>{'[^/]+' if not m.group(0).endswith(':path') else '.+'})",
path,
)
return f"^{pattern}$"
for api, api_routes in api_to_routes.items():
if api not in impls:
continue
for route, webmethod in api_routes:
impl = impls[api]
func = getattr(impl, route.name)
# Get the first (and typically only) method from the set, filtering out HEAD
available_methods = [m for m in route.methods if m != "HEAD"]
if not available_methods:
continue # Skip if only HEAD method is available
method = available_methods[0].lower()
if method not in route_impls:
route_impls[method] = {}
route_impls[method][_convert_path_to_regex(route.path)] = (
func,
route.path,
webmethod,
)
return route_impls
def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> RouteMatch:
"""Find the matching endpoint implementation for a given method and path.
Args:
method: HTTP method (GET, POST, etc.)
path: URL path to match against
route_impls: A dictionary of endpoint implementations
Returns:
A tuple of (endpoint_function, path_params, route_path, webmethod_metadata)
Raises:
ValueError: If no matching endpoint is found
"""
impls = route_impls.get(method.lower())
if not impls:
raise ValueError(f"No endpoint found for {path}")
for regex, (func, route_path, webmethod) in impls.items():
match = re.match(regex, path)
if match:
# Extract named groups from the regex match
path_params = match.groupdict()
return func, path_params, route_path, webmethod
raise ValueError(f"No endpoint found for {path}")

View file

@ -0,0 +1,625 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import asyncio
import functools
import inspect
import json
import logging
import os
import ssl
import sys
import traceback
import warnings
from collections.abc import Callable
from contextlib import asynccontextmanager
from importlib.metadata import version as parse_version
from pathlib import Path
from typing import Annotated, Any, get_origin
import rich.pretty
import yaml
from aiohttp import hdrs
from fastapi import Body, FastAPI, HTTPException, Request
from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from openai import BadRequestError
from pydantic import BaseModel, ValidationError
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.cli.utils import add_config_template_args, get_config_from_args
from llama_stack.core.access_control.access_control import AccessDeniedError
from llama_stack.core.datatypes import (
AuthenticationRequiredError,
LoggingConfig,
StackRunConfig,
)
from llama_stack.core.distribution import builtin_automatically_routed_apis
from llama_stack.core.external import ExternalApiSpec, load_external_apis
from llama_stack.core.request_headers import (
PROVIDER_DATA_VAR,
request_provider_data_context,
user_from_scope,
)
from llama_stack.core.resolver import InvalidProviderError
from llama_stack.core.server.routes import (
find_matching_route,
get_all_api_routes,
initialize_route_impls,
)
from llama_stack.core.stack import (
cast_image_name_to_string,
construct_stack,
replace_env_vars,
shutdown_stack,
validate_env_pair,
)
from llama_stack.core.utils.config import redact_sensitive_fields
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_template
from llama_stack.core.utils.context import preserve_contexts_async_generator
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
TelemetryAdapter,
)
from llama_stack.providers.utils.telemetry.tracing import (
CURRENT_TRACE_CONTEXT,
end_trace,
setup_logger,
start_trace,
)
from .auth import AuthenticationMiddleware
from .quota import QuotaMiddleware
REPO_ROOT = Path(__file__).parent.parent.parent.parent
logger = get_logger(name=__name__, category="server")
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
log = file if hasattr(file, "write") else sys.stderr
traceback.print_stack(file=log)
log.write(warnings.formatwarning(message, category, filename, lineno, line))
if os.environ.get("LLAMA_STACK_TRACE_WARNINGS"):
warnings.showwarning = warn_with_traceback
def create_sse_event(data: Any) -> str:
if isinstance(data, BaseModel):
data = data.model_dump_json()
else:
data = json.dumps(data)
return f"data: {data}\n\n"
async def global_exception_handler(request: Request, exc: Exception):
traceback.print_exception(exc)
http_exc = translate_exception(exc)
return JSONResponse(status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}})
def translate_exception(exc: Exception) -> HTTPException | RequestValidationError:
if isinstance(exc, ValidationError):
exc = RequestValidationError(exc.errors())
if isinstance(exc, RequestValidationError):
return HTTPException(
status_code=400,
detail={
"errors": [
{
"loc": list(error["loc"]),
"msg": error["msg"],
"type": error["type"],
}
for error in exc.errors()
]
},
)
elif isinstance(exc, ValueError):
return HTTPException(status_code=400, detail=f"Invalid value: {str(exc)}")
elif isinstance(exc, BadRequestError):
return HTTPException(status_code=400, detail=str(exc))
elif isinstance(exc, PermissionError | AccessDeniedError):
return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}")
elif isinstance(exc, asyncio.TimeoutError | TimeoutError):
return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}")
elif isinstance(exc, NotImplementedError):
return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}")
elif isinstance(exc, AuthenticationRequiredError):
return HTTPException(status_code=401, detail=f"Authentication required: {str(exc)}")
else:
return HTTPException(
status_code=500,
detail="Internal server error: An unexpected error occurred.",
)
async def shutdown(app):
"""Initiate a graceful shutdown of the application.
Handled by the lifespan context manager. The shutdown process involves
shutting down all implementations registered in the application.
"""
await shutdown_stack(app.__llama_stack_impls__)
@asynccontextmanager
async def lifespan(app: FastAPI):
logger.info("Starting up")
yield
logger.info("Shutting down")
await shutdown(app)
def is_streaming_request(func_name: str, request: Request, **kwargs):
# TODO: pass the api method and punt it to the Protocol definition directly
return kwargs.get("stream", False)
async def maybe_await(value):
if inspect.iscoroutine(value):
return await value
return value
async def sse_generator(event_gen_coroutine):
event_gen = None
try:
event_gen = await event_gen_coroutine
async for item in event_gen:
yield create_sse_event(item)
await asyncio.sleep(0.01)
except asyncio.CancelledError:
logger.info("Generator cancelled")
if event_gen:
await event_gen.aclose()
except Exception as e:
logger.exception("Error in sse_generator")
yield create_sse_event(
{
"error": {
"message": str(translate_exception(e)),
},
}
)
async def log_request_pre_validation(request: Request):
if request.method in ("POST", "PUT", "PATCH"):
try:
body_bytes = await request.body()
if body_bytes:
try:
parsed_body = json.loads(body_bytes.decode())
log_output = rich.pretty.pretty_repr(parsed_body)
except (json.JSONDecodeError, UnicodeDecodeError):
log_output = repr(body_bytes)
logger.debug(f"Incoming raw request body for {request.method} {request.url.path}:\n{log_output}")
else:
logger.debug(f"Incoming {request.method} {request.url.path} request with empty body.")
except Exception as e:
logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")
def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
@functools.wraps(func)
async def route_handler(request: Request, **kwargs):
# Get auth attributes from the request scope
user = user_from_scope(request.scope)
await log_request_pre_validation(request)
# Use context manager with both provider data and auth attributes
with request_provider_data_context(request.headers, user):
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
try:
if is_streaming:
gen = preserve_contexts_async_generator(
sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
)
return StreamingResponse(gen, media_type="text/event-stream")
else:
value = func(**kwargs)
result = await maybe_await(value)
if isinstance(result, PaginatedResponse) and result.url is None:
result.url = route
return result
except Exception as e:
if logger.isEnabledFor(logging.DEBUG):
logger.exception(f"Error executing endpoint {route=} {method=}")
else:
logger.error(f"Error executing endpoint {route=} {method=}: {str(e)}")
raise translate_exception(e) from e
sig = inspect.signature(func)
new_params = [inspect.Parameter("request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request)]
new_params.extend(sig.parameters.values())
path_params = extract_path_params(route)
if method == "post":
# Annotate parameters that are in the path with Path(...) and others with Body(...),
# but preserve existing File() and Form() annotations for multipart form data
new_params = (
[new_params[0]]
+ [
(
param.replace(annotation=Annotated[param.annotation, FastapiPath(..., title=param.name)])
if param.name in path_params
else (
param # Keep original annotation if it's already an Annotated type
if get_origin(param.annotation) is Annotated
else param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])
)
)
for param in new_params[1:]
]
)
route_handler.__signature__ = sig.replace(parameters=new_params)
return route_handler
class TracingMiddleware:
def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]):
self.app = app
self.impls = impls
self.external_apis = external_apis
# FastAPI built-in paths that should bypass custom routing
self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
async def __call__(self, scope, receive, send):
if scope.get("type") == "lifespan":
return await self.app(scope, receive, send)
path = scope.get("path", "")
# Check if the path is a FastAPI built-in path
if path.startswith(self.fastapi_paths):
# Pass through to FastAPI's built-in handlers
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
return await self.app(scope, receive, send)
if not hasattr(self, "route_impls"):
self.route_impls = initialize_route_impls(self.impls, self.external_apis)
try:
_, _, route_path, webmethod = find_matching_route(
scope.get("method", hdrs.METH_GET), path, self.route_impls
)
except ValueError:
# If no matching endpoint is found, pass through to FastAPI
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
return await self.app(scope, receive, send)
trace_attributes = {"__location__": "server", "raw_path": path}
# Extract W3C trace context headers and store as trace attributes
headers = dict(scope.get("headers", []))
traceparent = headers.get(b"traceparent", b"").decode()
if traceparent:
trace_attributes["traceparent"] = traceparent
tracestate = headers.get(b"tracestate", b"").decode()
if tracestate:
trace_attributes["tracestate"] = tracestate
trace_path = webmethod.descriptive_name or route_path
trace_context = await start_trace(trace_path, trace_attributes)
async def send_with_trace_id(message):
if message["type"] == "http.response.start":
headers = message.get("headers", [])
headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
message["headers"] = headers
await send(message)
try:
return await self.app(scope, receive, send_with_trace_id)
finally:
await end_trace()
class ClientVersionMiddleware:
def __init__(self, app):
self.app = app
self.server_version = parse_version("llama-stack")
async def __call__(self, scope, receive, send):
if scope["type"] == "http":
headers = dict(scope.get("headers", []))
client_version = headers.get(b"x-llamastack-client-version", b"").decode()
if client_version:
try:
client_version_parts = tuple(map(int, client_version.split(".")[:2]))
server_version_parts = tuple(map(int, self.server_version.split(".")[:2]))
if client_version_parts != server_version_parts:
async def send_version_error(send):
await send(
{
"type": "http.response.start",
"status": 426,
"headers": [[b"content-type", b"application/json"]],
}
)
error_msg = json.dumps(
{
"error": {
"message": f"Client version {client_version} is not compatible with server version {self.server_version}. Please update your client."
}
}
).encode()
await send({"type": "http.response.body", "body": error_msg})
return await send_version_error(send)
except (ValueError, IndexError):
# If version parsing fails, let the request through
pass
return await self.app(scope, receive, send)
def main(args: argparse.Namespace | None = None):
"""Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
add_config_template_args(parser)
parser.add_argument(
"--port",
type=int,
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
help="Port to listen on",
)
parser.add_argument(
"--env",
action="append",
help="Environment variables in KEY=value format. Can be specified multiple times.",
)
# Determine whether the server args are being passed by the "run" command, if this is the case
# the args will be passed as a Namespace object to the main function, otherwise they will be
# parsed from the command line
if args is None:
args = parser.parse_args()
config_or_template = get_config_from_args(args)
config_file = resolve_config_or_template(config_or_template, Mode.RUN)
logger_config = None
with open(config_file) as fp:
config_contents = yaml.safe_load(fp)
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
logger_config = LoggingConfig(**cfg)
logger = get_logger(name=__name__, category="server", config=logger_config)
if args.env:
for env_pair in args.env:
try:
key, value = validate_env_pair(env_pair)
logger.info(f"Setting CLI environment variable {key} => {value}")
os.environ[key] = value
except ValueError as e:
logger.error(f"Error: {str(e)}")
sys.exit(1)
config = replace_env_vars(config_contents)
config = StackRunConfig(**cast_image_name_to_string(config))
_log_run_config(run_config=config)
app = FastAPI(
lifespan=lifespan,
docs_url="/docs",
redoc_url="/redoc",
openapi_url="/openapi.json",
)
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
app.add_middleware(ClientVersionMiddleware)
try:
# Create and set the event loop that will be used for both construction and server runtime
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Construct the stack in the persistent event loop
impls = loop.run_until_complete(construct_stack(config))
except InvalidProviderError as e:
logger.error(f"Error: {str(e)}")
sys.exit(1)
if config.server.auth:
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth, impls=impls)
else:
if config.server.quota:
quota = config.server.quota
logger.warning(
"Configured authenticated_max_requests (%d) but no auth is enabled; "
"falling back to anonymous_max_requests (%d) for all the requests",
quota.authenticated_max_requests,
quota.anonymous_max_requests,
)
if config.server.quota:
logger.info("Enabling quota middleware for authenticated and anonymous clients")
quota = config.server.quota
anonymous_max_requests = quota.anonymous_max_requests
# if auth is disabled, use the anonymous max requests
authenticated_max_requests = quota.authenticated_max_requests if config.server.auth else anonymous_max_requests
kv_config = quota.kvstore
window_map = {"day": 86400}
window_seconds = window_map[quota.period.value]
app.add_middleware(
QuotaMiddleware,
kv_config=kv_config,
anonymous_max_requests=anonymous_max_requests,
authenticated_max_requests=authenticated_max_requests,
window_seconds=window_seconds,
)
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
else:
setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
# Load external APIs if configured
external_apis = load_external_apis(config)
all_routes = get_all_api_routes(external_apis)
if config.apis:
apis_to_serve = set(config.apis)
else:
apis_to_serve = set(impls.keys())
for inf in builtin_automatically_routed_apis():
# if we do not serve the corresponding router API, we should not serve the routing table API
if inf.router_api.value not in apis_to_serve:
continue
apis_to_serve.add(inf.routing_table_api.value)
apis_to_serve.add("inspect")
apis_to_serve.add("providers")
for api_str in apis_to_serve:
api = Api(api_str)
routes = all_routes[api]
try:
impl = impls[api]
except KeyError as e:
raise ValueError(f"Could not find provider implementation for {api} API") from e
for route, _ in routes:
if not hasattr(impl, route.name):
# ideally this should be a typing violation already
raise ValueError(f"Could not find method {route.name} on {impl}!")
impl_method = getattr(impl, route.name)
# Filter out HEAD method since it's automatically handled by FastAPI for GET routes
available_methods = [m for m in route.methods if m != "HEAD"]
if not available_methods:
raise ValueError(f"No methods found for {route.name} on {impl}")
method = available_methods[0]
logger.debug(f"{method} {route.path}")
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
getattr(app, method.lower())(route.path, response_model=None)(
create_dynamic_typed_route(
impl_method,
method.lower(),
route.path,
)
)
logger.debug(f"serving APIs: {apis_to_serve}")
app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler)
app.__llama_stack_impls__ = impls
app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
import uvicorn
# Configure SSL if certificates are provided
port = args.port or config.server.port
ssl_config = None
keyfile = config.server.tls_keyfile
certfile = config.server.tls_certfile
if keyfile and certfile:
ssl_config = {
"ssl_keyfile": keyfile,
"ssl_certfile": certfile,
}
if config.server.tls_cafile:
ssl_config["ssl_ca_certs"] = config.server.tls_cafile
ssl_config["ssl_cert_reqs"] = ssl.CERT_REQUIRED
logger.info(
f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}\n CA: {config.server.tls_cafile}"
)
else:
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
listen_host = config.server.host or ["::", "0.0.0.0"]
logger.info(f"Listening on {listen_host}:{port}")
uvicorn_config = {
"app": app,
"host": listen_host,
"port": port,
"lifespan": "on",
"log_level": logger.getEffectiveLevel(),
"log_config": logger_config,
}
if ssl_config:
uvicorn_config.update(ssl_config)
# Run uvicorn in the existing event loop to preserve background tasks
# We need to catch KeyboardInterrupt because uvicorn's signal handling
# re-raises SIGINT signals using signal.raise_signal(), which Python
# converts to KeyboardInterrupt. Without this catch, we'd get a confusing
# stack trace when using Ctrl+C or kill -2 (SIGINT).
# SIGTERM (kill -15) works fine without this because Python doesn't
# have a default handler for it.
#
# Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
# signal handling but this is quite intrusive and not worth the effort.
try:
loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
except (KeyboardInterrupt, SystemExit):
logger.info("Received interrupt signal, shutting down gracefully...")
finally:
if not loop.is_closed():
logger.debug("Closing event loop")
loop.close()
def _log_run_config(run_config: StackRunConfig):
"""Logs the run config with redacted fields and disabled providers removed."""
logger.info("Run configuration:")
safe_config = redact_sensitive_fields(run_config.model_dump(mode="json"))
clean_config = remove_disabled_providers(safe_config)
logger.info(yaml.dump(clean_config, indent=2))
def extract_path_params(route: str) -> list[str]:
segments = route.split("/")
params = [seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")]
# to handle path params like {param:path}
params = [param.split(":")[0] for param in params]
return params
def remove_disabled_providers(obj):
if isinstance(obj, dict):
keys = ["provider_id", "shield_id", "provider_model_id", "model_id"]
if any(k in obj and obj[k] in ("__disabled__", "", None) for k in keys):
return None
return {k: v for k, v in ((k, remove_disabled_providers(v)) for k, v in obj.items()) if v is not None}
elif isinstance(obj, list):
return [item for item in (remove_disabled_providers(i) for i in obj) if item is not None]
else:
return obj
if __name__ == "__main__":
main()

450
llama_stack/core/stack.py Normal file
View file

@ -0,0 +1,450 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import importlib.resources
import os
import re
import tempfile
from typing import Any
import yaml
from llama_stack.apis.agents import Agents
from llama_stack.apis.batch_inference import BatchInference
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.providers import Providers
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
from llama_stack.core.providers import ProviderImpl, ProviderImplConfig
from llama_stack.core.resolver import ProviderRegistry, resolve_impls
from llama_stack.core.routing_tables.common import CommonRoutingTableImpl
from llama_stack.core.store.registry import create_dist_registry
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
logger = get_logger(name=__name__, category="core")
class LlamaStack(
Providers,
VectorDBs,
Inference,
BatchInference,
Agents,
Safety,
SyntheticDataGeneration,
Datasets,
Telemetry,
PostTraining,
VectorIO,
Eval,
Benchmarks,
Scoring,
ScoringFunctions,
DatasetIO,
Models,
Shields,
Inspect,
ToolGroups,
ToolRuntime,
RAGToolRuntime,
Files,
):
pass
RESOURCES = [
("models", Api.models, "register_model", "list_models"),
("shields", Api.shields, "register_shield", "list_shields"),
("vector_dbs", Api.vector_dbs, "register_vector_db", "list_vector_dbs"),
("datasets", Api.datasets, "register_dataset", "list_datasets"),
(
"scoring_fns",
Api.scoring_functions,
"register_scoring_function",
"list_scoring_functions",
),
("benchmarks", Api.benchmarks, "register_benchmark", "list_benchmarks"),
("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups"),
]
REGISTRY_REFRESH_INTERVAL_SECONDS = 300
REGISTRY_REFRESH_TASK = None
TEST_RECORDING_CONTEXT = None
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
for rsrc, api, register_method, list_method in RESOURCES:
objects = getattr(run_config, rsrc)
if api not in impls:
continue
method = getattr(impls[api], register_method)
for obj in objects:
logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
# Do not register models on disabled providers
if hasattr(obj, "provider_id") and (not obj.provider_id or obj.provider_id == "__disabled__"):
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
continue
# we want to maintain the type information in arguments to method.
# instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
# we use model_dump() to find all the attrs and then getattr to get the still typed value.
await method(**{k: getattr(obj, k) for k in obj.model_dump().keys()})
method = getattr(impls[api], list_method)
response = await method()
objects_to_process = response.data if hasattr(response, "data") else response
for obj in objects_to_process:
logger.debug(
f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
)
class EnvVarError(Exception):
def __init__(self, var_name: str, path: str = ""):
self.var_name = var_name
self.path = path
super().__init__(
f"Environment variable '{var_name}' not set or empty {f'at {path}' if path else ''}. "
f"Use ${{env.{var_name}:=default_value}} to provide a default value, "
f"${{env.{var_name}:+value_if_set}} to make the field conditional, "
f"or ensure the environment variable is set."
)
def replace_env_vars(config: Any, path: str = "") -> Any:
if isinstance(config, dict):
result = {}
for k, v in config.items():
try:
result[k] = replace_env_vars(v, f"{path}.{k}" if path else k)
except EnvVarError as e:
raise EnvVarError(e.var_name, e.path) from None
return result
elif isinstance(config, list):
result = []
for i, v in enumerate(config):
try:
# Special handling for providers: first resolve the provider_id to check if provider
# is disabled so that we can skip config env variable expansion and avoid validation errors
if isinstance(v, dict) and "provider_id" in v:
try:
resolved_provider_id = replace_env_vars(v["provider_id"], f"{path}[{i}].provider_id")
if resolved_provider_id == "__disabled__":
logger.debug(
f"Skipping config env variable expansion for disabled provider: {v.get('provider_id', '')}"
)
# Create a copy with resolved provider_id but original config
disabled_provider = v.copy()
disabled_provider["provider_id"] = resolved_provider_id
continue
except EnvVarError:
# If we can't resolve the provider_id, continue with normal processing
pass
# Normal processing for non-disabled providers
result.append(replace_env_vars(v, f"{path}[{i}]"))
except EnvVarError as e:
raise EnvVarError(e.var_name, e.path) from None
return result
elif isinstance(config, str):
# Pattern supports bash-like syntax: := for default and :+ for conditional and a optional value
pattern = r"\${env\.([A-Z0-9_]+)(?::([=+])([^}]*))?}"
def get_env_var(match: re.Match):
env_var = match.group(1)
operator = match.group(2) # '=' for default, '+' for conditional
value_expr = match.group(3)
env_value = os.environ.get(env_var)
if operator == "=": # Default value syntax: ${env.FOO:=default}
# If the env is set like ${env.FOO:=default} then use the env value when set
if env_value:
value = env_value
else:
# If the env is not set, look for a default value
# value_expr returns empty string (not None) when not matched
# This means ${env.FOO:=} and it's accepted and returns empty string - just like bash
if value_expr == "":
return ""
else:
value = value_expr
elif operator == "+": # Conditional value syntax: ${env.FOO:+value_if_set}
# If the env is set like ${env.FOO:+value_if_set} then use the value_if_set
if env_value:
if value_expr:
value = value_expr
# This means ${env.FOO:+}
else:
# Just like bash, this doesn't care whether the env is set or not and applies
# the value, in this case the empty string
return ""
else:
# Just like bash, this doesn't care whether the env is set or not, since it's not set
# we return an empty string
value = ""
else: # No operator case: ${env.FOO}
if not env_value:
raise EnvVarError(env_var, path)
value = env_value
# expand "~" from the values
return os.path.expanduser(value)
try:
result = re.sub(pattern, get_env_var, config)
return _convert_string_to_proper_type(result)
except EnvVarError as e:
raise EnvVarError(e.var_name, e.path) from None
return config
def _convert_string_to_proper_type(value: str) -> Any:
# This might be tricky depending on what the config type is, if 'str | None' we are
# good, if 'str' we need to keep the empty string... 'str | None' is more common and
# providers config should be typed this way.
# TODO: we could try to load the config class and see if the config has a field with type 'str | None'
# and then convert the empty string to None or not
if value == "":
return None
lowered = value.lower()
if lowered == "true":
return True
elif lowered == "false":
return False
try:
return int(value)
except ValueError:
pass
try:
return float(value)
except ValueError:
pass
return value
def cast_image_name_to_string(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Ensure that any value for a key 'image_name' in a config_dict is a string"""
if "image_name" in config_dict and config_dict["image_name"] is not None:
config_dict["image_name"] = str(config_dict["image_name"])
return config_dict
def validate_env_pair(env_pair: str) -> tuple[str, str]:
"""Validate and split an environment variable key-value pair."""
try:
key, value = env_pair.split("=", 1)
key = key.strip()
if not key:
raise ValueError(f"Empty key in environment variable pair: {env_pair}")
if not all(c.isalnum() or c == "_" for c in key):
raise ValueError(f"Key must contain only alphanumeric characters and underscores: {key}")
return key, value
except ValueError as e:
raise ValueError(
f"Invalid environment variable format '{env_pair}': {str(e)}. Expected format: KEY=value"
) from e
def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig) -> None:
"""Add internal implementations (inspect and providers) to the implementations dictionary.
Args:
impls: Dictionary of API implementations
run_config: Stack run configuration
"""
inspect_impl = DistributionInspectImpl(
DistributionInspectConfig(run_config=run_config),
deps=impls,
)
impls[Api.inspect] = inspect_impl
providers_impl = ProviderImpl(
ProviderImplConfig(run_config=run_config),
deps=impls,
)
impls[Api.providers] = providers_impl
# Produces a stack of providers for the given run config. Not all APIs may be
# asked for in the run config.
async def construct_stack(
run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
) -> dict[Api, Any]:
if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
from llama_stack.testing.inference_recorder import setup_inference_recording
global TEST_RECORDING_CONTEXT
TEST_RECORDING_CONTEXT = setup_inference_recording()
if TEST_RECORDING_CONTEXT:
TEST_RECORDING_CONTEXT.__enter__()
logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
policy = run_config.server.auth.access_policy if run_config.server.auth else []
impls = await resolve_impls(
run_config, provider_registry or get_provider_registry(run_config), dist_registry, policy
)
# Add internal implementations after all other providers are resolved
add_internal_implementations(impls, run_config)
await register_resources(run_config, impls)
await refresh_registry_once(impls)
global REGISTRY_REFRESH_TASK
REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(impls))
def cb(task):
import traceback
if task.cancelled():
logger.error("Model refresh task cancelled")
elif task.exception():
logger.error(f"Model refresh task failed: {task.exception()}")
traceback.print_exception(task.exception())
else:
logger.debug("Model refresh task completed")
REGISTRY_REFRESH_TASK.add_done_callback(cb)
return impls
async def shutdown_stack(impls: dict[Api, Any]):
for impl in impls.values():
impl_name = impl.__class__.__name__
logger.info(f"Shutting down {impl_name}")
try:
if hasattr(impl, "shutdown"):
await asyncio.wait_for(impl.shutdown(), timeout=5)
else:
logger.warning(f"No shutdown method for {impl_name}")
except TimeoutError:
logger.exception(f"Shutdown timeout for {impl_name}")
except (Exception, asyncio.CancelledError) as e:
logger.exception(f"Failed to shutdown {impl_name}: {e}")
global TEST_RECORDING_CONTEXT
if TEST_RECORDING_CONTEXT:
try:
TEST_RECORDING_CONTEXT.__exit__(None, None, None)
except Exception as e:
logger.error(f"Error during inference recording cleanup: {e}")
global REGISTRY_REFRESH_TASK
if REGISTRY_REFRESH_TASK:
REGISTRY_REFRESH_TASK.cancel()
async def refresh_registry_once(impls: dict[Api, Any]):
logger.debug("refreshing registry")
routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
for routing_table in routing_tables:
await routing_table.refresh()
async def refresh_registry_task(impls: dict[Api, Any]):
logger.info("starting registry refresh task")
while True:
await refresh_registry_once(impls)
await asyncio.sleep(REGISTRY_REFRESH_INTERVAL_SECONDS)
def get_stack_run_config_from_template(template: str) -> StackRunConfig:
template_path = importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"
with importlib.resources.as_file(template_path) as path:
if not path.exists():
raise ValueError(f"Template '{template}' not found at {template_path}")
run_config = yaml.safe_load(path.open())
return StackRunConfig(**replace_env_vars(run_config))
def run_config_from_adhoc_config_spec(
adhoc_config_spec: str, provider_registry: ProviderRegistry | None = None
) -> StackRunConfig:
"""
Create an adhoc distribution from a list of API providers.
The list should be of the form "api=provider", e.g. "inference=fireworks". If you have
multiple pairs, separate them with commas or semicolons, e.g. "inference=fireworks,safety=llama-guard,agents=meta-reference"
"""
api_providers = adhoc_config_spec.replace(";", ",").split(",")
provider_registry = provider_registry or get_provider_registry()
distro_dir = tempfile.mkdtemp()
provider_configs_by_api = {}
for api_provider in api_providers:
api_str, provider = api_provider.split("=")
api = Api(api_str)
providers_by_type = provider_registry[api]
provider_spec = providers_by_type.get(provider)
if not provider_spec:
provider_spec = providers_by_type.get(f"inline::{provider}")
if not provider_spec:
provider_spec = providers_by_type.get(f"remote::{provider}")
if not provider_spec:
raise ValueError(
f"Provider {provider} (or remote::{provider} or inline::{provider}) not found for API {api}"
)
# call method "sample_run_config" on the provider spec config class
provider_config_type = instantiate_class_type(provider_spec.config_class)
provider_config = replace_env_vars(provider_config_type.sample_run_config(__distro_dir__=distro_dir))
provider_configs_by_api[api_str] = [
Provider(
provider_id=provider,
provider_type=provider_spec.provider_type,
config=provider_config,
)
]
config = StackRunConfig(
image_name="distro-test",
apis=list(provider_configs_by_api.keys()),
providers=provider_configs_by_api,
)
return config

134
llama_stack/core/start_stack.sh Executable file
View file

@ -0,0 +1,134 @@
#!/usr/bin/env bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
PYPI_VERSION=${PYPI_VERSION:-}
VIRTUAL_ENV=${VIRTUAL_ENV:-}
set -euo pipefail
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color
error_handler() {
echo "Error occurred in script at line: ${1}" >&2
exit 1
}
trap 'error_handler ${LINENO}' ERR
if [ $# -lt 3 ]; then
echo "Usage: $0 <env_type> <env_path_or_name> <port> [--config <yaml_config>] [--env KEY=VALUE]..."
exit 1
fi
env_type="$1"
shift
env_path_or_name="$1"
container_image="localhost/$env_path_or_name"
shift
port="$1"
shift
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh"
# Initialize variables
yaml_config=""
env_vars=""
other_args=""
# Process remaining arguments
while [[ $# -gt 0 ]]; do
case "$1" in
--config)
if [[ -n "$2" ]]; then
yaml_config="$2"
shift 2
else
echo -e "${RED}Error: $1 requires a CONFIG argument${NC}" >&2
exit 1
fi
;;
--env)
if [[ -n "$2" ]]; then
env_vars="$env_vars --env $2"
shift 2
else
echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
exit 1
fi
;;
*)
other_args="$other_args $1"
shift
;;
esac
done
# Check if yaml_config is required based on env_type
if [[ "$env_type" == "venv" || "$env_type" == "conda" ]] && [ -z "$yaml_config" ]; then
echo -e "${RED}Error: --config is required for venv and conda environments${NC}" >&2
exit 1
fi
PYTHON_BINARY="python"
case "$env_type" in
"venv")
if [ -n "$VIRTUAL_ENV" ] && [ "$VIRTUAL_ENV" == "$env_path_or_name" ]; then
echo -e "${GREEN}Virtual environment already activated${NC}" >&2
else
# Activate virtual environment
if [ ! -d "$env_path_or_name" ]; then
echo -e "${RED}Error: Virtual environment not found at $env_path_or_name${NC}" >&2
exit 1
fi
if [ ! -f "$env_path_or_name/bin/activate" ]; then
echo -e "${RED}Error: Virtual environment activate binary not found at $env_path_or_name/bin/activate" >&2
exit 1
fi
source "$env_path_or_name/bin/activate"
fi
;;
"conda")
if ! is_command_available conda; then
echo -e "${RED}Error: conda not found" >&2
exit 1
fi
eval "$(conda shell.bash hook)"
conda deactivate && conda activate "$env_path_or_name"
PYTHON_BINARY="$CONDA_PREFIX/bin/python"
;;
*)
esac
if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
set -x
if [ -n "$yaml_config" ]; then
yaml_config_arg="$yaml_config"
else
yaml_config_arg=""
fi
$PYTHON_BINARY -m llama_stack.core.server.server \
$yaml_config_arg \
--port "$port" \
$env_vars \
$other_args
elif [[ "$env_type" == "container" ]]; then
echo -e "${RED}Warning: Llama Stack no longer supports running Containers via the 'llama stack run' command.${NC}"
echo -e "Please refer to the documentation for more information: https://llama-stack.readthedocs.io/en/latest/distributions/building_distro.html#llama-stack-build"
exit 1
fi

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .registry import * # noqa: F401 F403

View file

@ -0,0 +1,204 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from contextlib import asynccontextmanager
from typing import Protocol
import pydantic
from llama_stack.core.datatypes import RoutableObjectWithProvider
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
logger = get_logger(__name__, category="core")
class DistributionRegistry(Protocol):
async def get_all(self) -> list[RoutableObjectWithProvider]: ...
async def initialize(self) -> None: ...
async def get(self, identifier: str) -> RoutableObjectWithProvider | None: ...
def get_cached(self, identifier: str) -> RoutableObjectWithProvider | None: ...
async def update(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: ...
async def register(self, obj: RoutableObjectWithProvider) -> bool: ...
async def delete(self, type: str, identifier: str) -> None: ...
REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v9"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
def _get_registry_key_range() -> tuple[str, str]:
"""Returns the start and end keys for the registry range query."""
start_key = f"{REGISTER_PREFIX}:{KEY_VERSION}"
return start_key, f"{start_key}\xff"
def _parse_registry_values(values: list[str]) -> list[RoutableObjectWithProvider]:
"""Utility function to parse registry values into RoutableObjectWithProvider objects."""
all_objects = []
for value in values:
try:
obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value)
all_objects.append(obj)
except pydantic.ValidationError as e:
logger.error(f"Error parsing registry value, raw value: {value}. Error: {e}")
continue
return all_objects
class DiskDistributionRegistry(DistributionRegistry):
def __init__(self, kvstore: KVStore):
self.kvstore = kvstore
async def initialize(self) -> None:
pass
def get_cached(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
# Disk registry does not have a cache
raise NotImplementedError("Disk registry does not have a cache")
async def get_all(self) -> list[RoutableObjectWithProvider]:
start_key, end_key = _get_registry_key_range()
values = await self.kvstore.values_in_range(start_key, end_key)
return _parse_registry_values(values)
async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
json_str = await self.kvstore.get(KEY_FORMAT.format(type=type, identifier=identifier))
if not json_str:
return None
try:
return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str)
except pydantic.ValidationError as e:
logger.error(f"Error parsing registry value for {type}:{identifier}, raw value: {json_str}. Error: {e}")
return None
async def update(self, obj: RoutableObjectWithProvider) -> None:
await self.kvstore.set(
KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),
obj.model_dump_json(),
)
return obj
async def register(self, obj: RoutableObjectWithProvider) -> bool:
existing_obj = await self.get(obj.type, obj.identifier)
# dont register if the object's providerid already exists
if existing_obj and existing_obj.provider_id == obj.provider_id:
return False
await self.kvstore.set(
KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),
obj.model_dump_json(),
)
return True
async def delete(self, type: str, identifier: str) -> None:
await self.kvstore.delete(KEY_FORMAT.format(type=type, identifier=identifier))
class CachedDiskDistributionRegistry(DiskDistributionRegistry):
def __init__(self, kvstore: KVStore):
super().__init__(kvstore)
self.cache: dict[tuple[str, str], RoutableObjectWithProvider] = {}
self._initialized = False
self._initialize_lock = asyncio.Lock()
self._cache_lock = asyncio.Lock()
@asynccontextmanager
async def _locked_cache(self):
"""Context manager for safely accessing the cache with a lock."""
async with self._cache_lock:
yield self.cache
async def _ensure_initialized(self):
"""Ensures the registry is initialized before operations."""
if self._initialized:
return
async with self._initialize_lock:
if self._initialized:
return
start_key, end_key = _get_registry_key_range()
values = await self.kvstore.values_in_range(start_key, end_key)
objects = _parse_registry_values(values)
async with self._locked_cache() as cache:
for obj in objects:
cache_key = (obj.type, obj.identifier)
cache[cache_key] = obj
self._initialized = True
async def initialize(self) -> None:
await self._ensure_initialized()
def get_cached(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
return self.cache.get((type, identifier), None)
async def get_all(self) -> list[RoutableObjectWithProvider]:
await self._ensure_initialized()
async with self._locked_cache() as cache:
return list(cache.values())
async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
await self._ensure_initialized()
cache_key = (type, identifier)
async with self._locked_cache() as cache:
return cache.get(cache_key, None)
async def register(self, obj: RoutableObjectWithProvider) -> bool:
await self._ensure_initialized()
success = await super().register(obj)
if success:
cache_key = (obj.type, obj.identifier)
async with self._locked_cache() as cache:
cache[cache_key] = obj
return success
async def update(self, obj: RoutableObjectWithProvider) -> None:
await super().update(obj)
cache_key = (obj.type, obj.identifier)
async with self._locked_cache() as cache:
cache[cache_key] = obj
return obj
async def delete(self, type: str, identifier: str) -> None:
await super().delete(type, identifier)
cache_key = (type, identifier)
async with self._locked_cache() as cache:
if cache_key in cache:
del cache[cache_key]
async def create_dist_registry(
metadata_store: KVStoreConfig | None,
image_name: str,
) -> tuple[CachedDiskDistributionRegistry, KVStore]:
# instantiate kvstore for storing and retrieving distribution metadata
if metadata_store:
dist_kvstore = await kvstore_impl(metadata_store)
else:
dist_kvstore = await kvstore_impl(
SqliteKVStoreConfig(db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix())
)
dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
await dist_registry.initialize()
return dist_registry, dist_kvstore

View file

@ -0,0 +1,11 @@
# More info on playground configuration can be found here:
# https://llama-stack.readthedocs.io/en/latest/playground
FROM python:3.12-slim
WORKDIR /app
COPY . /app/
RUN /usr/local/bin/python -m pip install --upgrade pip && \
/usr/local/bin/pip3 install -r requirements.txt
EXPOSE 8501
ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]

View file

@ -0,0 +1,50 @@
# (Experimental) LLama Stack UI
## Docker Setup
:warning: This is a work in progress.
## Developer Setup
1. Start up Llama Stack API server. More details [here](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html).
```
llama stack build --template together --image-type conda
llama stack run together
```
2. (Optional) Register datasets and eval tasks as resources. If you want to run pre-configured evaluation flows (e.g. Evaluations (Generation + Scoring) Page).
```bash
llama-stack-client datasets register \
--dataset-id "mmlu" \
--provider-id "huggingface" \
--url "https://huggingface.co/datasets/llamastack/evals" \
--metadata '{"path": "llamastack/evals", "name": "evals__mmlu__details", "split": "train"}' \
--schema '{"input_query": {"type": "string"}, "expected_answer": {"type": "string", "chat_completion_input": {"type": "string"}}}'
```
```bash
llama-stack-client benchmarks register \
--eval-task-id meta-reference-mmlu \
--provider-id meta-reference \
--dataset-id mmlu \
--scoring-functions basic::regex_parser_multiple_choice_answer
```
3. Start Streamlit UI
```bash
uv run --with ".[ui]" streamlit run llama_stack.core/ui/app.py
```
## Environment Variables
| Environment Variable | Description | Default Value |
|----------------------------|------------------------------------|---------------------------|
| LLAMA_STACK_ENDPOINT | The endpoint for the Llama Stack | http://localhost:8321 |
| FIREWORKS_API_KEY | API key for Fireworks provider | (empty string) |
| TOGETHER_API_KEY | API key for Together provider | (empty string) |
| SAMBANOVA_API_KEY | API key for SambaNova provider | (empty string) |
| OPENAI_API_KEY | API key for OpenAI provider | (empty string) |

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
def main():
# Evaluation pages
application_evaluation_page = st.Page(
"page/evaluations/app_eval.py",
title="Evaluations (Scoring)",
icon="📊",
default=False,
)
native_evaluation_page = st.Page(
"page/evaluations/native_eval.py",
title="Evaluations (Generation + Scoring)",
icon="📊",
default=False,
)
# Playground pages
chat_page = st.Page("page/playground/chat.py", title="Chat", icon="💬", default=True)
rag_page = st.Page("page/playground/rag.py", title="RAG", icon="💬", default=False)
tool_page = st.Page("page/playground/tools.py", title="Tools", icon="🛠", default=False)
# Distribution pages
resources_page = st.Page("page/distribution/resources.py", title="Resources", icon="🔍", default=False)
provider_page = st.Page(
"page/distribution/providers.py",
title="API Providers",
icon="🔍",
default=False,
)
pg = st.navigation(
{
"Playground": [
chat_page,
rag_page,
tool_page,
application_evaluation_page,
native_evaluation_page,
],
"Inspect": [provider_page, resources_page],
},
expanded=False,
)
pg.run()
if __name__ == "__main__":
main()

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from llama_stack_client import LlamaStackClient
class LlamaStackApi:
def __init__(self):
self.client = LlamaStackClient(
base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:8321"),
provider_data={
"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY", ""),
"together_api_key": os.environ.get("TOGETHER_API_KEY", ""),
"sambanova_api_key": os.environ.get("SAMBANOVA_API_KEY", ""),
"openai_api_key": os.environ.get("OPENAI_API_KEY", ""),
"tavily_search_api_key": os.environ.get("TAVILY_SEARCH_API_KEY", ""),
},
)
def run_scoring(self, row, scoring_function_ids: list[str], scoring_params: dict | None):
"""Run scoring on a single row"""
if not scoring_params:
scoring_params = dict.fromkeys(scoring_function_ids)
return self.client.scoring.score(input_rows=[row], scoring_functions=scoring_params)
llama_stack_api = LlamaStackApi()

View file

@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import os
import pandas as pd
import streamlit as st
def process_dataset(file):
if file is None:
return "No file uploaded", None
try:
# Determine file type and read accordingly
file_ext = os.path.splitext(file.name)[1].lower()
if file_ext == ".csv":
df = pd.read_csv(file)
elif file_ext in [".xlsx", ".xls"]:
df = pd.read_excel(file)
else:
return "Unsupported file format. Please upload a CSV or Excel file.", None
return df
except Exception as e:
st.error(f"Error processing file: {str(e)}")
return None
def data_url_from_file(file) -> str:
file_content = file.getvalue()
base64_content = base64.b64encode(file_content).decode("utf-8")
mime_type = file.type
data_url = f"data:{mime_type};base64,{base64_content}"
return data_url

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from llama_stack.core.ui.modules.api import llama_stack_api
def datasets():
st.header("Datasets")
datasets_info = {d.identifier: d.to_dict() for d in llama_stack_api.client.datasets.list()}
if len(datasets_info) > 0:
selected_dataset = st.selectbox("Select a dataset", list(datasets_info.keys()))
st.json(datasets_info[selected_dataset], expanded=True)

View file

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from llama_stack.core.ui.modules.api import llama_stack_api
def benchmarks():
# Benchmarks Section
st.header("Benchmarks")
benchmarks_info = {d.identifier: d.to_dict() for d in llama_stack_api.client.benchmarks.list()}
if len(benchmarks_info) > 0:
selected_benchmark = st.selectbox("Select an eval task", list(benchmarks_info.keys()), key="benchmark_inspect")
st.json(benchmarks_info[selected_benchmark], expanded=True)

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from llama_stack.core.ui.modules.api import llama_stack_api
def models():
# Models Section
st.header("Models")
models_info = {m.identifier: m.to_dict() for m in llama_stack_api.client.models.list()}
selected_model = st.selectbox("Select a model", list(models_info.keys()))
st.json(models_info[selected_model])

View file

@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from llama_stack.core.ui.modules.api import llama_stack_api
def providers():
st.header("🔍 API Providers")
apis_providers_lst = llama_stack_api.client.providers.list()
api_to_providers = {}
for api_provider in apis_providers_lst:
if api_provider.api in api_to_providers:
api_to_providers[api_provider.api].append(api_provider)
else:
api_to_providers[api_provider.api] = [api_provider]
for api in api_to_providers.keys():
st.markdown(f"###### {api}")
st.dataframe([x.to_dict() for x in api_to_providers[api]], width=500)
providers()

View file

@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from streamlit_option_menu import option_menu
from llama_stack.core.ui.page.distribution.datasets import datasets
from llama_stack.core.ui.page.distribution.eval_tasks import benchmarks
from llama_stack.core.ui.page.distribution.models import models
from llama_stack.core.ui.page.distribution.scoring_functions import scoring_functions
from llama_stack.core.ui.page.distribution.shields import shields
from llama_stack.core.ui.page.distribution.vector_dbs import vector_dbs
def resources_page():
options = [
"Models",
"Vector Databases",
"Shields",
"Scoring Functions",
"Datasets",
"Benchmarks",
]
icons = ["magic", "memory", "shield", "file-bar-graph", "database", "list-task"]
selected_resource = option_menu(
None,
options,
icons=icons,
orientation="horizontal",
styles={
"nav-link": {
"font-size": "12px",
},
},
)
if selected_resource == "Benchmarks":
benchmarks()
elif selected_resource == "Vector Databases":
vector_dbs()
elif selected_resource == "Datasets":
datasets()
elif selected_resource == "Models":
models()
elif selected_resource == "Scoring Functions":
scoring_functions()
elif selected_resource == "Shields":
shields()
resources_page()

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from llama_stack.core.ui.modules.api import llama_stack_api
def scoring_functions():
st.header("Scoring Functions")
scoring_functions_info = {s.identifier: s.to_dict() for s in llama_stack_api.client.scoring_functions.list()}
selected_scoring_function = st.selectbox("Select a scoring function", list(scoring_functions_info.keys()))
st.json(scoring_functions_info[selected_scoring_function], expanded=True)

View file

@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from llama_stack.core.ui.modules.api import llama_stack_api
def shields():
# Shields Section
st.header("Shields")
shields_info = {s.identifier: s.to_dict() for s in llama_stack_api.client.shields.list()}
selected_shield = st.selectbox("Select a shield", list(shields_info.keys()))
st.json(shields_info[selected_shield])

View file

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from llama_stack.core.ui.modules.api import llama_stack_api
def vector_dbs():
st.header("Vector Databases")
vector_dbs_info = {v.identifier: v.to_dict() for v in llama_stack_api.client.vector_dbs.list()}
if len(vector_dbs_info) > 0:
selected_vector_db = st.selectbox("Select a vector database", list(vector_dbs_info.keys()))
st.json(vector_dbs_info[selected_vector_db])
else:
st.info("No vector databases found")

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,143 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import pandas as pd
import streamlit as st
from llama_stack.core.ui.modules.api import llama_stack_api
from llama_stack.core.ui.modules.utils import process_dataset
def application_evaluation_page():
st.set_page_config(page_title="Evaluations (Scoring)", page_icon="🦙")
st.title("📊 Evaluations (Scoring)")
# File uploader
uploaded_file = st.file_uploader("Upload Dataset", type=["csv", "xlsx", "xls"])
if uploaded_file is None:
st.error("No file uploaded")
return
# Process uploaded file
df = process_dataset(uploaded_file)
if df is None:
st.error("Error processing file")
return
# Display dataset information
st.success("Dataset loaded successfully!")
# Display dataframe preview
st.subheader("Dataset Preview")
st.dataframe(df)
# Select Scoring Functions to Run Evaluation On
st.subheader("Select Scoring Functions")
scoring_functions = llama_stack_api.client.scoring_functions.list()
scoring_functions = {sf.identifier: sf for sf in scoring_functions}
scoring_functions_names = list(scoring_functions.keys())
selected_scoring_functions = st.multiselect(
"Choose one or more scoring functions",
options=scoring_functions_names,
help="Choose one or more scoring functions.",
)
available_models = llama_stack_api.client.models.list()
available_models = [m.identifier for m in available_models]
scoring_params = {}
if selected_scoring_functions:
st.write("Selected:")
for scoring_fn_id in selected_scoring_functions:
scoring_fn = scoring_functions[scoring_fn_id]
st.write(f"- **{scoring_fn_id}**: {scoring_fn.description}")
new_params = None
if scoring_fn.params:
new_params = {}
for param_name, param_value in scoring_fn.params.to_dict().items():
if param_name == "type":
new_params[param_name] = param_value
continue
if param_name == "judge_model":
value = st.selectbox(
f"Select **{param_name}** for {scoring_fn_id}",
options=available_models,
index=0,
key=f"{scoring_fn_id}_{param_name}",
)
new_params[param_name] = value
else:
value = st.text_area(
f"Enter value for **{param_name}** in {scoring_fn_id} in valid JSON format",
value=json.dumps(param_value, indent=2),
height=80,
)
try:
new_params[param_name] = json.loads(value)
except json.JSONDecodeError:
st.error(f"Invalid JSON for **{param_name}** in {scoring_fn_id}")
st.json(new_params)
scoring_params[scoring_fn_id] = new_params
# Add run evaluation button & slider
total_rows = len(df)
num_rows = st.slider("Number of rows to evaluate", 1, total_rows, total_rows)
if st.button("Run Evaluation"):
progress_text = "Running evaluation..."
progress_bar = st.progress(0, text=progress_text)
rows = df.to_dict(orient="records")
if num_rows < total_rows:
rows = rows[:num_rows]
# Create separate containers for progress text and results
progress_text_container = st.empty()
results_container = st.empty()
output_res = {}
for i, r in enumerate(rows):
# Update progress
progress = i / len(rows)
progress_bar.progress(progress, text=progress_text)
# Run evaluation for current row
score_res = llama_stack_api.run_scoring(
r,
scoring_function_ids=selected_scoring_functions,
scoring_params=scoring_params,
)
for k in r.keys():
if k not in output_res:
output_res[k] = []
output_res[k].append(r[k])
for fn_id in selected_scoring_functions:
if fn_id not in output_res:
output_res[fn_id] = []
output_res[fn_id].append(score_res.results[fn_id].score_rows[0])
# Display current row results using separate containers
progress_text_container.write(f"Expand to see current processed result ({i + 1} / {len(rows)})")
results_container.json(
score_res.to_json(),
expanded=2,
)
progress_bar.progress(1.0, text="Evaluation complete!")
# Display results in dataframe
if output_res:
output_df = pd.DataFrame(output_res)
st.subheader("Evaluation Results")
st.dataframe(output_df)
application_evaluation_page()

View file

@ -0,0 +1,253 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import pandas as pd
import streamlit as st
from llama_stack.core.ui.modules.api import llama_stack_api
def select_benchmark_1():
# Select Benchmarks
st.subheader("1. Choose An Eval Task")
benchmarks = llama_stack_api.client.benchmarks.list()
benchmarks = {et.identifier: et for et in benchmarks}
benchmarks_names = list(benchmarks.keys())
selected_benchmark = st.selectbox(
"Choose an eval task.",
options=benchmarks_names,
help="Choose an eval task. Each eval task is parameterized by a dataset, and list of scoring functions.",
)
with st.expander("View Eval Task"):
st.json(benchmarks[selected_benchmark], expanded=True)
st.session_state["selected_benchmark"] = selected_benchmark
st.session_state["benchmarks"] = benchmarks
if st.button("Confirm", key="confirm_1"):
st.session_state["selected_benchmark_1_next"] = True
def define_eval_candidate_2():
if not st.session_state.get("selected_benchmark_1_next", None):
return
st.subheader("2. Define Eval Candidate")
st.info(
"""
Define the configurations for the evaluation candidate model or agent used for generation.
Select "model" if you want to run generation with inference API, or "agent" if you want to run generation with agent API through specifying AgentConfig.
"""
)
with st.expander("Define Eval Candidate", expanded=True):
# Define Eval Candidate
candidate_type = st.radio("Candidate Type", ["model", "agent"])
available_models = llama_stack_api.client.models.list()
available_models = [model.identifier for model in available_models]
selected_model = st.selectbox(
"Choose a model",
available_models,
index=0,
)
# Sampling Parameters
st.markdown("##### Sampling Parameters")
temperature = st.slider(
"Temperature",
min_value=0.0,
max_value=1.0,
value=0.0,
step=0.1,
help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable",
)
top_p = st.slider(
"Top P",
min_value=0.0,
max_value=1.0,
value=0.95,
step=0.1,
)
max_tokens = st.slider(
"Max Tokens",
min_value=0,
max_value=4096,
value=512,
step=1,
help="The maximum number of tokens to generate",
)
repetition_penalty = st.slider(
"Repetition Penalty",
min_value=1.0,
max_value=2.0,
value=1.0,
step=0.1,
help="Controls the likelihood for generating the same word or phrase multiple times in the same sentence or paragraph. 1 implies no penalty, 2 will strongly discourage model to repeat words or phrases.",
)
if candidate_type == "model":
if temperature > 0.0:
strategy = {
"type": "top_p",
"temperature": temperature,
"top_p": top_p,
}
else:
strategy = {"type": "greedy"}
eval_candidate = {
"type": "model",
"model": selected_model,
"sampling_params": {
"strategy": strategy,
"max_tokens": max_tokens,
"repetition_penalty": repetition_penalty,
},
}
elif candidate_type == "agent":
system_prompt = st.text_area(
"System Prompt",
value="You are a helpful AI assistant.",
help="Initial instructions given to the AI to set its behavior and context",
)
tools_json = st.text_area(
"Tools Configuration (JSON)",
value=json.dumps(
[
{
"type": "brave_search",
"engine": "brave",
"api_key": "ENTER_BRAVE_API_KEY_HERE",
}
]
),
help="Enter tool configurations in JSON format. Each tool should have a name, description, and parameters.",
height=200,
)
try:
tools = json.loads(tools_json)
except json.JSONDecodeError:
st.error("Invalid JSON format for tools configuration")
tools = []
eval_candidate = {
"type": "agent",
"config": {
"model": selected_model,
"instructions": system_prompt,
"tools": tools,
"tool_choice": "auto",
"tool_prompt_format": "json",
"input_shields": [],
"output_shields": [],
"enable_session_persistence": False,
},
}
st.session_state["eval_candidate"] = eval_candidate
if st.button("Confirm", key="confirm_2"):
st.session_state["selected_eval_candidate_2_next"] = True
def run_evaluation_3():
if not st.session_state.get("selected_eval_candidate_2_next", None):
return
st.subheader("3. Run Evaluation")
# Add info box to explain configurations being used
st.info(
"""
Review the configurations that will be used for this evaluation run, make any necessary changes, and then click the "Run Evaluation" button.
"""
)
selected_benchmark = st.session_state["selected_benchmark"]
benchmarks = st.session_state["benchmarks"]
eval_candidate = st.session_state["eval_candidate"]
dataset_id = benchmarks[selected_benchmark].dataset_id
rows = llama_stack_api.client.datasets.iterrows(
dataset_id=dataset_id,
)
total_rows = len(rows.data)
# Add number of examples control
num_rows = st.number_input(
"Number of Examples to Evaluate",
min_value=1,
max_value=total_rows,
value=5,
help="Number of examples from the dataset to evaluate. ",
)
benchmark_config = {
"type": "benchmark",
"eval_candidate": eval_candidate,
"scoring_params": {},
}
with st.expander("View Evaluation Task", expanded=True):
st.json(benchmarks[selected_benchmark], expanded=True)
with st.expander("View Evaluation Task Configuration", expanded=True):
st.json(benchmark_config, expanded=True)
# Add run button and handle evaluation
if st.button("Run Evaluation"):
progress_text = "Running evaluation..."
progress_bar = st.progress(0, text=progress_text)
rows = rows.data
if num_rows < total_rows:
rows = rows[:num_rows]
# Create separate containers for progress text and results
progress_text_container = st.empty()
results_container = st.empty()
output_res = {}
for i, r in enumerate(rows):
# Update progress
progress = i / len(rows)
progress_bar.progress(progress, text=progress_text)
# Run evaluation for current row
eval_res = llama_stack_api.client.eval.evaluate_rows(
benchmark_id=selected_benchmark,
input_rows=[r],
scoring_functions=benchmarks[selected_benchmark].scoring_functions,
benchmark_config=benchmark_config,
)
for k in r.keys():
if k not in output_res:
output_res[k] = []
output_res[k].append(r[k])
for k in eval_res.generations[0].keys():
if k not in output_res:
output_res[k] = []
output_res[k].append(eval_res.generations[0][k])
for scoring_fn in benchmarks[selected_benchmark].scoring_functions:
if scoring_fn not in output_res:
output_res[scoring_fn] = []
output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0])
progress_text_container.write(f"Expand to see current processed result ({i + 1} / {len(rows)})")
results_container.json(eval_res, expanded=2)
progress_bar.progress(1.0, text="Evaluation complete!")
# Display results in dataframe
if output_res:
output_df = pd.DataFrame(output_res)
st.subheader("Evaluation Results")
st.dataframe(output_df)
def native_evaluation_page():
st.set_page_config(page_title="Evaluations (Generation + Scoring)", page_icon="🦙")
st.title("📊 Evaluations (Generation + Scoring)")
select_benchmark_1()
define_eval_candidate_2()
run_evaluation_3()
native_evaluation_page()

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,130 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from llama_stack.core.ui.modules.api import llama_stack_api
# Sidebar configurations
with st.sidebar:
st.header("Configuration")
available_models = llama_stack_api.client.models.list()
available_models = [model.identifier for model in available_models if model.model_type == "llm"]
selected_model = st.selectbox(
"Choose a model",
available_models,
index=0,
)
temperature = st.slider(
"Temperature",
min_value=0.0,
max_value=1.0,
value=0.0,
step=0.1,
help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable",
)
top_p = st.slider(
"Top P",
min_value=0.0,
max_value=1.0,
value=0.95,
step=0.1,
)
max_tokens = st.slider(
"Max Tokens",
min_value=0,
max_value=4096,
value=512,
step=1,
help="The maximum number of tokens to generate",
)
repetition_penalty = st.slider(
"Repetition Penalty",
min_value=1.0,
max_value=2.0,
value=1.0,
step=0.1,
help="Controls the likelihood for generating the same word or phrase multiple times in the same sentence or paragraph. 1 implies no penalty, 2 will strongly discourage model to repeat words or phrases.",
)
stream = st.checkbox("Stream", value=True)
system_prompt = st.text_area(
"System Prompt",
value="You are a helpful AI assistant.",
help="Initial instructions given to the AI to set its behavior and context",
)
# Add clear chat button to sidebar
if st.button("Clear Chat", use_container_width=True):
st.session_state.messages = []
st.rerun()
# Main chat interface
st.title("🦙 Chat")
# Initialize chat history
if "messages" not in st.session_state:
st.session_state.messages = []
# Display chat messages
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
# Chat input
if prompt := st.chat_input("Example: What is Llama Stack?"):
# Add user message to chat history
st.session_state.messages.append({"role": "user", "content": prompt})
# Display user message
with st.chat_message("user"):
st.markdown(prompt)
# Display assistant response
with st.chat_message("assistant"):
message_placeholder = st.empty()
full_response = ""
if temperature > 0.0:
strategy = {
"type": "top_p",
"temperature": temperature,
"top_p": top_p,
}
else:
strategy = {"type": "greedy"}
response = llama_stack_api.client.inference.chat_completion(
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt},
],
model_id=selected_model,
stream=stream,
sampling_params={
"strategy": strategy,
"max_tokens": max_tokens,
"repetition_penalty": repetition_penalty,
},
)
if stream:
for chunk in response:
if chunk.event.event_type == "progress":
full_response += chunk.event.delta.text
message_placeholder.markdown(full_response + "")
message_placeholder.markdown(full_response)
else:
full_response = response.completion_message.content
message_placeholder.markdown(full_response)
st.session_state.messages.append({"role": "assistant", "content": full_response})

View file

@ -0,0 +1,301 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
import streamlit as st
from llama_stack_client import Agent, AgentEventLogger, RAGDocument
from llama_stack.apis.common.content_types import ToolCallDelta
from llama_stack.core.ui.modules.api import llama_stack_api
from llama_stack.core.ui.modules.utils import data_url_from_file
def rag_chat_page():
st.title("🦙 RAG")
def reset_agent_and_chat():
st.session_state.clear()
st.cache_resource.clear()
def should_disable_input():
return "displayed_messages" in st.session_state and len(st.session_state.displayed_messages) > 0
def log_message(message):
with st.chat_message(message["role"]):
if "tool_output" in message and message["tool_output"]:
with st.expander(label="Tool Output", expanded=False, icon="🛠"):
st.write(message["tool_output"])
st.markdown(message["content"])
with st.sidebar:
# File/Directory Upload Section
st.subheader("Upload Documents", divider=True)
uploaded_files = st.file_uploader(
"Upload file(s) or directory",
accept_multiple_files=True,
type=["txt", "pdf", "doc", "docx"], # Add more file types as needed
)
# Process uploaded files
if uploaded_files:
st.success(f"Successfully uploaded {len(uploaded_files)} files")
# Add memory bank name input field
vector_db_name = st.text_input(
"Document Collection Name",
value="rag_vector_db",
help="Enter a unique identifier for this document collection",
)
if st.button("Create Document Collection"):
documents = [
RAGDocument(
document_id=uploaded_file.name,
content=data_url_from_file(uploaded_file),
)
for i, uploaded_file in enumerate(uploaded_files)
]
providers = llama_stack_api.client.providers.list()
vector_io_provider = None
for x in providers:
if x.api == "vector_io":
vector_io_provider = x.provider_id
llama_stack_api.client.vector_dbs.register(
vector_db_id=vector_db_name, # Use the user-provided name
embedding_dimension=384,
embedding_model="all-MiniLM-L6-v2",
provider_id=vector_io_provider,
)
# insert documents using the custom vector db name
llama_stack_api.client.tool_runtime.rag_tool.insert(
vector_db_id=vector_db_name, # Use the user-provided name
documents=documents,
chunk_size_in_tokens=512,
)
st.success("Vector database created successfully!")
st.subheader("RAG Parameters", divider=True)
rag_mode = st.radio(
"RAG mode",
["Direct", "Agent-based"],
captions=[
"RAG is performed by directly retrieving the information and augmenting the user query",
"RAG is performed by an agent activating a dedicated knowledge search tool.",
],
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
# select memory banks
vector_dbs = llama_stack_api.client.vector_dbs.list()
vector_dbs = [vector_db.identifier for vector_db in vector_dbs]
selected_vector_dbs = st.multiselect(
label="Select Document Collections to use in RAG queries",
options=vector_dbs,
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
st.subheader("Inference Parameters", divider=True)
available_models = llama_stack_api.client.models.list()
available_models = [model.identifier for model in available_models if model.model_type == "llm"]
selected_model = st.selectbox(
label="Choose a model",
options=available_models,
index=0,
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
system_prompt = st.text_area(
"System Prompt",
value="You are a helpful assistant. ",
help="Initial instructions given to the AI to set its behavior and context",
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
temperature = st.slider(
"Temperature",
min_value=0.0,
max_value=1.0,
value=0.0,
step=0.1,
help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable",
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
top_p = st.slider(
"Top P",
min_value=0.0,
max_value=1.0,
value=0.95,
step=0.1,
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
# Add clear chat button to sidebar
if st.button("Clear Chat", use_container_width=True):
reset_agent_and_chat()
st.rerun()
# Chat Interface
if "messages" not in st.session_state:
st.session_state.messages = []
if "displayed_messages" not in st.session_state:
st.session_state.displayed_messages = []
# Display chat history
for message in st.session_state.displayed_messages:
log_message(message)
if temperature > 0.0:
strategy = {
"type": "top_p",
"temperature": temperature,
"top_p": top_p,
}
else:
strategy = {"type": "greedy"}
@st.cache_resource
def create_agent():
return Agent(
llama_stack_api.client,
model=selected_model,
instructions=system_prompt,
sampling_params={
"strategy": strategy,
},
tools=[
dict(
name="builtin::rag/knowledge_search",
args={
"vector_db_ids": list(selected_vector_dbs),
},
)
],
)
if rag_mode == "Agent-based":
agent = create_agent()
if "agent_session_id" not in st.session_state:
st.session_state["agent_session_id"] = agent.create_session(session_name=f"rag_demo_{uuid.uuid4()}")
session_id = st.session_state["agent_session_id"]
def agent_process_prompt(prompt):
# Add user message to chat history
st.session_state.messages.append({"role": "user", "content": prompt})
# Send the prompt to the agent
response = agent.create_turn(
messages=[
{
"role": "user",
"content": prompt,
}
],
session_id=session_id,
)
# Display assistant response
with st.chat_message("assistant"):
retrieval_message_placeholder = st.expander(label="Tool Output", expanded=False, icon="🛠")
message_placeholder = st.empty()
full_response = ""
retrieval_response = ""
for log in AgentEventLogger().log(response):
log.print()
if log.role == "tool_execution":
retrieval_response += log.content.replace("====", "").strip()
retrieval_message_placeholder.write(retrieval_response)
else:
full_response += log.content
message_placeholder.markdown(full_response + "")
message_placeholder.markdown(full_response)
st.session_state.messages.append({"role": "assistant", "content": full_response})
st.session_state.displayed_messages.append(
{"role": "assistant", "content": full_response, "tool_output": retrieval_response}
)
def direct_process_prompt(prompt):
# Add the system prompt in the beginning of the conversation
if len(st.session_state.messages) == 0:
st.session_state.messages.append({"role": "system", "content": system_prompt})
# Query the vector DB
rag_response = llama_stack_api.client.tool_runtime.rag_tool.query(
content=prompt, vector_db_ids=list(selected_vector_dbs)
)
prompt_context = rag_response.content
with st.chat_message("assistant"):
with st.expander(label="Retrieval Output", expanded=False):
st.write(prompt_context)
retrieval_message_placeholder = st.empty()
message_placeholder = st.empty()
full_response = ""
retrieval_response = ""
# Construct the extended prompt
extended_prompt = f"Please answer the following query using the context below.\n\nCONTEXT:\n{prompt_context}\n\nQUERY:\n{prompt}"
# Run inference directly
st.session_state.messages.append({"role": "user", "content": extended_prompt})
response = llama_stack_api.client.inference.chat_completion(
messages=st.session_state.messages,
model_id=selected_model,
sampling_params={
"strategy": strategy,
},
stream=True,
)
# Display assistant response
for chunk in response:
response_delta = chunk.event.delta
if isinstance(response_delta, ToolCallDelta):
retrieval_response += response_delta.tool_call.replace("====", "").strip()
retrieval_message_placeholder.info(retrieval_response)
else:
full_response += chunk.event.delta.text
message_placeholder.markdown(full_response + "")
message_placeholder.markdown(full_response)
response_dict = {"role": "assistant", "content": full_response, "stop_reason": "end_of_message"}
st.session_state.messages.append(response_dict)
st.session_state.displayed_messages.append(response_dict)
# Chat input
if prompt := st.chat_input("Ask a question about your documents"):
# Add user message to chat history
st.session_state.displayed_messages.append({"role": "user", "content": prompt})
# Display user message
with st.chat_message("user"):
st.markdown(prompt)
# store the prompt to process it after page refresh
st.session_state.prompt = prompt
# force page refresh to disable the settings widgets
st.rerun()
if "prompt" in st.session_state and st.session_state.prompt is not None:
if rag_mode == "Agent-based":
agent_process_prompt(st.session_state.prompt)
else: # rag_mode == "Direct"
direct_process_prompt(st.session_state.prompt)
st.session_state.prompt = None
rag_chat_page()

View file

@ -0,0 +1,352 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import enum
import json
import uuid
import streamlit as st
from llama_stack_client import Agent
from llama_stack_client.lib.agents.react.agent import ReActAgent
from llama_stack_client.lib.agents.react.tool_parser import ReActOutput
from llama_stack.core.ui.modules.api import llama_stack_api
class AgentType(enum.Enum):
REGULAR = "Regular"
REACT = "ReAct"
def tool_chat_page():
st.title("🛠 Tools")
client = llama_stack_api.client
models = client.models.list()
model_list = [model.identifier for model in models if model.api_model_type == "llm"]
tool_groups = client.toolgroups.list()
tool_groups_list = [tool_group.identifier for tool_group in tool_groups]
mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")]
builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")]
selected_vector_dbs = []
def reset_agent():
st.session_state.clear()
st.cache_resource.clear()
with st.sidebar:
st.title("Configuration")
st.subheader("Model")
model = st.selectbox(label="Model", options=model_list, on_change=reset_agent, label_visibility="collapsed")
st.subheader("Available ToolGroups")
toolgroup_selection = st.pills(
label="Built-in tools",
options=builtin_tools_list,
selection_mode="multi",
on_change=reset_agent,
format_func=lambda tool: "".join(tool.split("::")[1:]),
help="List of built-in tools from your llama stack server.",
)
if "builtin::rag" in toolgroup_selection:
vector_dbs = llama_stack_api.client.vector_dbs.list() or []
if not vector_dbs:
st.info("No vector databases available for selection.")
vector_dbs = [vector_db.identifier for vector_db in vector_dbs]
selected_vector_dbs = st.multiselect(
label="Select Document Collections to use in RAG queries",
options=vector_dbs,
on_change=reset_agent,
)
mcp_selection = st.pills(
label="MCP Servers",
options=mcp_tools_list,
selection_mode="multi",
on_change=reset_agent,
format_func=lambda tool: "".join(tool.split("::")[1:]),
help="List of MCP servers registered to your llama stack server.",
)
toolgroup_selection.extend(mcp_selection)
grouped_tools = {}
total_tools = 0
for toolgroup_id in toolgroup_selection:
tools = client.tools.list(toolgroup_id=toolgroup_id)
grouped_tools[toolgroup_id] = [tool.identifier for tool in tools]
total_tools += len(tools)
st.markdown(f"Active Tools: 🛠 {total_tools}")
for group_id, tools in grouped_tools.items():
with st.expander(f"🔧 Tools from `{group_id}`"):
for idx, tool in enumerate(tools, start=1):
st.markdown(f"{idx}. `{tool.split(':')[-1]}`")
st.subheader("Agent Configurations")
st.subheader("Agent Type")
agent_type = st.radio(
label="Select Agent Type",
options=["Regular", "ReAct"],
on_change=reset_agent,
)
if agent_type == "ReAct":
agent_type = AgentType.REACT
else:
agent_type = AgentType.REGULAR
max_tokens = st.slider(
"Max Tokens",
min_value=0,
max_value=4096,
value=512,
step=64,
help="The maximum number of tokens to generate",
on_change=reset_agent,
)
for i, tool_name in enumerate(toolgroup_selection):
if tool_name == "builtin::rag":
tool_dict = dict(
name="builtin::rag",
args={
"vector_db_ids": list(selected_vector_dbs),
},
)
toolgroup_selection[i] = tool_dict
@st.cache_resource
def create_agent():
if "agent_type" in st.session_state and st.session_state.agent_type == AgentType.REACT:
return ReActAgent(
client=client,
model=model,
tools=toolgroup_selection,
response_format={
"type": "json_schema",
"json_schema": ReActOutput.model_json_schema(),
},
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
)
else:
return Agent(
client,
model=model,
instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
tools=toolgroup_selection,
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
)
st.session_state.agent_type = agent_type
agent = create_agent()
if "agent_session_id" not in st.session_state:
st.session_state["agent_session_id"] = agent.create_session(session_name=f"tool_demo_{uuid.uuid4()}")
session_id = st.session_state["agent_session_id"]
if "messages" not in st.session_state:
st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
for msg in st.session_state.messages:
with st.chat_message(msg["role"]):
st.markdown(msg["content"])
if prompt := st.chat_input(placeholder=""):
with st.chat_message("user"):
st.markdown(prompt)
st.session_state.messages.append({"role": "user", "content": prompt})
turn_response = agent.create_turn(
session_id=session_id,
messages=[{"role": "user", "content": prompt}],
stream=True,
)
def response_generator(turn_response):
if st.session_state.get("agent_type") == AgentType.REACT:
return _handle_react_response(turn_response)
else:
return _handle_regular_response(turn_response)
def _handle_react_response(turn_response):
current_step_content = ""
final_answer = None
tool_results = []
for response in turn_response:
if not hasattr(response.event, "payload"):
yield (
"\n\n🚨 :red[_Llama Stack server Error:_]\n"
"The response received is missing an expected `payload` attribute.\n"
"This could indicate a malformed response or an internal issue within the server.\n\n"
f"Error details: {response}"
)
return
payload = response.event.payload
if payload.event_type == "step_progress" and hasattr(payload.delta, "text"):
current_step_content += payload.delta.text
continue
if payload.event_type == "step_complete":
step_details = payload.step_details
if step_details.step_type == "inference":
yield from _process_inference_step(current_step_content, tool_results, final_answer)
current_step_content = ""
elif step_details.step_type == "tool_execution":
tool_results = _process_tool_execution(step_details, tool_results)
current_step_content = ""
else:
current_step_content = ""
if not final_answer and tool_results:
yield from _format_tool_results_summary(tool_results)
def _process_inference_step(current_step_content, tool_results, final_answer):
try:
react_output_data = json.loads(current_step_content)
thought = react_output_data.get("thought")
action = react_output_data.get("action")
answer = react_output_data.get("answer")
if answer and answer != "null" and answer is not None:
final_answer = answer
if thought:
with st.expander("🤔 Thinking...", expanded=False):
st.markdown(f":grey[__{thought}__]")
if action and isinstance(action, dict):
tool_name = action.get("tool_name")
tool_params = action.get("tool_params")
with st.expander(f'🛠 Action: Using tool "{tool_name}"', expanded=False):
st.json(tool_params)
if answer and answer != "null" and answer is not None:
yield f"\n\n✅ **Final Answer:**\n{answer}"
except json.JSONDecodeError:
yield f"\n\nFailed to parse ReAct step content:\n```json\n{current_step_content}\n```"
except Exception as e:
yield f"\n\nFailed to process ReAct step: {e}\n```json\n{current_step_content}\n```"
return final_answer
def _process_tool_execution(step_details, tool_results):
try:
if hasattr(step_details, "tool_responses") and step_details.tool_responses:
for tool_response in step_details.tool_responses:
tool_name = tool_response.tool_name
content = tool_response.content
tool_results.append((tool_name, content))
with st.expander(f'⚙️ Observation (Result from "{tool_name}")', expanded=False):
try:
parsed_content = json.loads(content)
st.json(parsed_content)
except json.JSONDecodeError:
st.code(content, language=None)
else:
with st.expander("⚙️ Observation", expanded=False):
st.markdown(":grey[_Tool execution step completed, but no response data found._]")
except Exception as e:
with st.expander("⚙️ Error in Tool Execution", expanded=False):
st.markdown(f":red[_Error processing tool execution: {str(e)}_]")
return tool_results
def _format_tool_results_summary(tool_results):
yield "\n\n**Here's what I found:**\n"
for tool_name, content in tool_results:
try:
parsed_content = json.loads(content)
if tool_name == "web_search" and "top_k" in parsed_content:
yield from _format_web_search_results(parsed_content)
elif "results" in parsed_content and isinstance(parsed_content["results"], list):
yield from _format_results_list(parsed_content["results"])
elif isinstance(parsed_content, dict) and len(parsed_content) > 0:
yield from _format_dict_results(parsed_content)
elif isinstance(parsed_content, list) and len(parsed_content) > 0:
yield from _format_list_results(parsed_content)
except json.JSONDecodeError:
yield f"\n**{tool_name}** was used but returned complex data. Check the observation for details.\n"
except (TypeError, AttributeError, KeyError, IndexError) as e:
print(f"Error processing {tool_name} result: {type(e).__name__}: {e}")
def _format_web_search_results(parsed_content):
for i, result in enumerate(parsed_content["top_k"], 1):
if i <= 3:
title = result.get("title", "Untitled")
url = result.get("url", "")
content_text = result.get("content", "").strip()
yield f"\n- **{title}**\n {content_text}\n [Source]({url})\n"
def _format_results_list(results):
for i, result in enumerate(results, 1):
if i <= 3:
if isinstance(result, dict):
name = result.get("name", result.get("title", "Result " + str(i)))
description = result.get("description", result.get("content", result.get("summary", "")))
yield f"\n- **{name}**\n {description}\n"
else:
yield f"\n- {result}\n"
def _format_dict_results(parsed_content):
yield "\n```\n"
for key, value in list(parsed_content.items())[:5]:
if isinstance(value, str) and len(value) < 100:
yield f"{key}: {value}\n"
else:
yield f"{key}: [Complex data]\n"
yield "```\n"
def _format_list_results(parsed_content):
yield "\n"
for _, item in enumerate(parsed_content[:3], 1):
if isinstance(item, str):
yield f"- {item}\n"
elif isinstance(item, dict) and "text" in item:
yield f"- {item['text']}\n"
elif isinstance(item, dict) and len(item) > 0:
first_value = next(iter(item.values()))
if isinstance(first_value, str) and len(first_value) < 100:
yield f"- {first_value}\n"
def _handle_regular_response(turn_response):
for response in turn_response:
if hasattr(response.event, "payload"):
print(response.event.payload)
if response.event.payload.event_type == "step_progress":
if hasattr(response.event.payload.delta, "text"):
yield response.event.payload.delta.text
if response.event.payload.event_type == "step_complete":
if response.event.payload.step_details.step_type == "tool_execution":
if response.event.payload.step_details.tool_calls:
tool_name = str(response.event.payload.step_details.tool_calls[0].tool_name)
yield f'\n\n🛠 :grey[_Using "{tool_name}" tool:_]\n\n'
else:
yield "No tool_calls present in step_details"
else:
yield f"Error occurred in the Llama Stack Cluster: {response}"
with st.chat_message("assistant"):
response_content = st.write_stream(response_generator(turn_response))
st.session_state.messages.append({"role": "assistant", "content": response_content})
tool_chat_page()

View file

@ -0,0 +1,5 @@
llama-stack>=0.2.1
llama-stack-client>=0.2.1
pandas
streamlit
streamlit-option-menu

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,30 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
def redact_sensitive_fields(data: dict[str, Any]) -> dict[str, Any]:
"""Redact sensitive information from config before printing."""
sensitive_patterns = ["api_key", "api_token", "password", "secret"]
def _redact_value(v: Any) -> Any:
if isinstance(v, dict):
return _redact_dict(v)
elif isinstance(v, list):
return [_redact_value(i) for i in v]
return v
def _redact_dict(d: dict[str, Any]) -> dict[str, Any]:
result = {}
for k, v in d.items():
if any(pattern in k.lower() for pattern in sensitive_patterns):
result[k] = "********"
else:
result[k] = _redact_value(v)
return result
return _redact_dict(data)

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from pathlib import Path
LLAMA_STACK_CONFIG_DIR = Path(os.getenv("LLAMA_STACK_CONFIG_DIR", os.path.expanduser("~/.llama/")))
DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"
DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints"
RUNTIME_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "runtime"
EXTERNAL_PROVIDERS_DIR = LLAMA_STACK_CONFIG_DIR / "providers.d"

View file

@ -0,0 +1,125 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import StrEnum
from pathlib import Path
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="config_resolution")
TEMPLATE_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "templates"
class Mode(StrEnum):
RUN = "run"
BUILD = "build"
def resolve_config_or_template(
config_or_template: str,
mode: Mode = Mode.RUN,
) -> Path:
"""
Resolve a config/template argument to a concrete config file path.
Args:
config_or_template: User input (file path, template name, or built distribution)
mode: Mode resolving for ("run", "build", "server")
Returns:
Path to the resolved config file
Raises:
ValueError: If resolution fails
"""
# Strategy 1: Try as file path first
config_path = Path(config_or_template)
if config_path.exists() and config_path.is_file():
logger.info(f"Using file path: {config_path}")
return config_path.resolve()
# Strategy 2: Try as template name (if no .yaml extension)
if not config_or_template.endswith(".yaml"):
template_config = _get_template_config_path(config_or_template, mode)
if template_config.exists():
logger.info(f"Using template: {template_config}")
return template_config
# Strategy 3: Try as built distribution name
distrib_config = DISTRIBS_BASE_DIR / f"llamastack-{config_or_template}" / f"{config_or_template}-{mode}.yaml"
if distrib_config.exists():
logger.info(f"Using built distribution: {distrib_config}")
return distrib_config
distrib_config = DISTRIBS_BASE_DIR / f"{config_or_template}" / f"{config_or_template}-{mode}.yaml"
if distrib_config.exists():
logger.info(f"Using built distribution: {distrib_config}")
return distrib_config
# Strategy 4: Failed - provide helpful error
raise ValueError(_format_resolution_error(config_or_template, mode))
def _get_template_config_path(template_name: str, mode: Mode) -> Path:
"""Get the config file path for a template."""
return TEMPLATE_DIR / template_name / f"{mode}.yaml"
def _format_resolution_error(config_or_template: str, mode: Mode) -> str:
"""Format a helpful error message for resolution failures."""
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
template_path = _get_template_config_path(config_or_template, mode)
distrib_path = DISTRIBS_BASE_DIR / f"llamastack-{config_or_template}" / f"{config_or_template}-{mode}.yaml"
distrib_path2 = DISTRIBS_BASE_DIR / f"{config_or_template}" / f"{config_or_template}-{mode}.yaml"
available_templates = _get_available_templates()
templates_str = ", ".join(available_templates) if available_templates else "none found"
return f"""Could not resolve config or template '{config_or_template}'.
Tried the following locations:
1. As file path: {Path(config_or_template).resolve()}
2. As template: {template_path}
3. As built distribution: ({distrib_path}, {distrib_path2})
Available templates: {templates_str}
Did you mean one of these templates?
{_format_template_suggestions(available_templates, config_or_template)}
"""
def _get_available_templates() -> list[str]:
"""Get list of available template names."""
if not TEMPLATE_DIR.exists() and not DISTRIBS_BASE_DIR.exists():
return []
return list(
set(
[d.name for d in TEMPLATE_DIR.iterdir() if d.is_dir() and not d.name.startswith(".")]
+ [d.name for d in DISTRIBS_BASE_DIR.iterdir() if d.is_dir() and not d.name.startswith(".")]
)
)
def _format_template_suggestions(templates: list[str], user_input: str) -> str:
"""Format template suggestions for error messages, showing closest matches first."""
if not templates:
return " (no templates found)"
import difflib
# Get up to 3 closest matches with similarity threshold of 0.3 (lower = more permissive)
close_matches = difflib.get_close_matches(user_input, templates, n=3, cutoff=0.3)
display_templates = close_matches if close_matches else templates[:3]
suggestions = [f" - {t}" for t in display_templates]
return "\n".join(suggestions)

View file

@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from contextvars import ContextVar
def preserve_contexts_async_generator[T](
gen: AsyncGenerator[T, None], context_vars: list[ContextVar]
) -> AsyncGenerator[T, None]:
"""
Wraps an async generator to preserve context variables across iterations.
This is needed because we start a new asyncio event loop for each streaming request,
and we need to preserve the context across the event loop boundary.
"""
# Capture initial context values
initial_context_values = {context_var.name: context_var.get() for context_var in context_vars}
async def wrapper() -> AsyncGenerator[T, None]:
while True:
try:
# Restore context values before any await
for context_var in context_vars:
context_var.set(initial_context_values[context_var.name])
item = await gen.__anext__()
# Update our tracked values with any changes made during this iteration
for context_var in context_vars:
initial_context_values[context_var.name] = context_var.get()
yield item
except StopAsyncIteration:
break
return wrapper()

View file

@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import importlib
def instantiate_class_type(fully_qualified_name):
module_name, class_name = fully_qualified_name.rsplit(".", 1)
module = importlib.import_module(module_name)
return getattr(module, class_name)

View file

@ -0,0 +1,143 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import os
import signal
import subprocess
import sys
from termcolor import cprint
log = logging.getLogger(__name__)
import importlib
import json
from pathlib import Path
from llama_stack.core.utils.image_types import LlamaStackImageType
def formulate_run_args(image_type: str, image_name: str) -> list[str]:
env_name = ""
if image_type == LlamaStackImageType.CONDA.value:
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
env_name = image_name or current_conda_env
if not env_name:
cprint(
"No current conda environment detected, please specify a conda environment name with --image-name",
color="red",
file=sys.stderr,
)
return
def get_conda_prefix(env_name):
# Conda "base" environment does not end with "base" in the
# prefix, so should be handled separately.
if env_name == "base":
return os.environ.get("CONDA_PREFIX")
# Get conda environments info
conda_env_info = json.loads(subprocess.check_output(["conda", "info", "--envs", "--json"]).decode())
envs = conda_env_info["envs"]
for envpath in envs:
if os.path.basename(envpath) == env_name:
return envpath
return None
cprint(f"Using conda environment: {env_name}", color="green", file=sys.stderr)
conda_prefix = get_conda_prefix(env_name)
if not conda_prefix:
cprint(
f"Conda environment {env_name} does not exist.",
color="red",
file=sys.stderr,
)
return
build_file = Path(conda_prefix) / "llamastack-build.yaml"
if not build_file.exists():
cprint(
f"Build file {build_file} does not exist.\n\nPlease run `llama stack build` or specify the correct conda environment name with --image-name",
color="red",
file=sys.stderr,
)
return
else:
# else must be venv since that is the only valid option left.
current_venv = os.environ.get("VIRTUAL_ENV")
env_name = image_name or current_venv
if not env_name:
cprint(
"No current virtual environment detected, please specify a virtual environment name with --image-name",
color="red",
file=sys.stderr,
)
return
cprint(f"Using virtual environment: {env_name}", file=sys.stderr)
script = importlib.resources.files("llama_stack") / "core/start_stack.sh"
run_args = [
script,
image_type,
env_name,
]
return run_args
def in_notebook():
try:
from IPython import get_ipython
if "IPKernelApp" not in get_ipython().config: # pragma: no cover
return False
except ImportError:
return False
except AttributeError:
return False
return True
def run_command(command: list[str]) -> int:
"""
Run a command with interrupt handling and output capture.
Uses subprocess.run with direct stream piping for better performance.
Args:
command (list): The command to run.
Returns:
int: The return code of the command.
"""
original_sigint = signal.getsignal(signal.SIGINT)
ctrl_c_pressed = False
def sigint_handler(signum, frame):
nonlocal ctrl_c_pressed
ctrl_c_pressed = True
log.info("\nCtrl-C detected. Aborting...")
try:
# Set up the signal handler
signal.signal(signal.SIGINT, sigint_handler)
# Run the command with stdout/stderr piped directly to system streams
result = subprocess.run(
command,
text=True,
check=False,
)
return result.returncode
except subprocess.SubprocessError as e:
log.error(f"Subprocess error: {e}")
return 1
except Exception as e:
log.exception(f"Unexpected error: {e}")
return 1
finally:
# Restore the original signal handler
signal.signal(signal.SIGINT, original_sigint)

View file

@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import enum
class LlamaStackImageType(enum.Enum):
CONTAINER = "container"
CONDA = "conda"
VENV = "venv"

View file

@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from .config_dirs import DEFAULT_CHECKPOINT_DIR
def model_local_dir(descriptor: str) -> str:
return str(Path(DEFAULT_CHECKPOINT_DIR) / (descriptor.replace(":", "-")))

View file

@ -0,0 +1,282 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import inspect
import json
import logging
from enum import Enum
from typing import Annotated, Any, Literal, Union, get_args, get_origin
from pydantic import BaseModel
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefinedType
log = logging.getLogger(__name__)
def is_list_of_primitives(field_type):
"""Check if a field type is a List of primitive types."""
origin = get_origin(field_type)
if origin is list or origin is list:
args = get_args(field_type)
if len(args) == 1 and args[0] in (int, float, str, bool):
return True
return False
def is_basemodel_without_fields(typ):
return inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) == 0
def can_recurse(typ):
return inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) > 0
def get_literal_values(field):
"""Extract literal values from a field if it's a Literal type."""
if get_origin(field.annotation) is Literal:
return get_args(field.annotation)
return None
def is_optional(field_type):
"""Check if a field type is Optional."""
return get_origin(field_type) is Union and type(None) in get_args(field_type)
def get_non_none_type(field_type):
"""Get the non-None type from an Optional type."""
return next(arg for arg in get_args(field_type) if arg is not type(None))
def manually_validate_field(model: type[BaseModel], field_name: str, value: Any):
validators = model.__pydantic_decorators__.field_validators
for _name, validator in validators.items():
if field_name in validator.info.fields:
validator.func(value)
return value
def is_discriminated_union(typ) -> bool:
if isinstance(typ, FieldInfo):
return typ.discriminator
else:
if get_origin(typ) is not Annotated:
return False
args = get_args(typ)
return len(args) >= 2 and args[1].discriminator
def prompt_for_discriminated_union(
field_name,
typ,
existing_value,
):
if isinstance(typ, FieldInfo):
inner_type = typ.annotation
discriminator = typ.discriminator
default_value = typ.default
else:
args = get_args(typ)
inner_type = args[0]
discriminator = args[1].discriminator
default_value = args[1].default
union_types = get_args(inner_type)
# Find the discriminator field in each union type
type_map = {}
for t in union_types:
disc_field = t.__fields__[discriminator]
literal_values = get_literal_values(disc_field)
if literal_values:
for value in literal_values:
type_map[value] = t
while True:
prompt = f"Enter `{discriminator}` for {field_name} (options: {', '.join(type_map.keys())})"
if default_value is not None:
prompt += f" (default: {default_value})"
discriminator_value = input(f"{prompt}: ")
if discriminator_value == "" and default_value is not None:
discriminator_value = default_value
if discriminator_value in type_map:
chosen_type = type_map[discriminator_value]
log.info(f"\nConfiguring {chosen_type.__name__}:")
if existing_value and (getattr(existing_value, discriminator) != discriminator_value):
existing_value = None
sub_config = prompt_for_config(chosen_type, existing_value)
# Set the discriminator field in the sub-config
setattr(sub_config, discriminator, discriminator_value)
return sub_config
else:
log.error(f"Invalid {discriminator}. Please try again.")
# This is somewhat elaborate, but does not purport to be comprehensive in any way.
# We should add handling for the most common cases to tide us over.
#
# doesn't support List[nested_class] yet or Dicts of any kind. needs a bunch of
# unit tests for coverage.
def prompt_for_config(config_type: type[BaseModel], existing_config: BaseModel | None = None) -> BaseModel:
"""
Recursively prompt the user for configuration values based on a Pydantic BaseModel.
Args:
config_type: A Pydantic BaseModel class representing the configuration structure.
Returns:
An instance of the config_type with user-provided values.
"""
config_data = {}
for field_name, field in config_type.__fields__.items():
field_type = field.annotation
existing_value = getattr(existing_config, field_name) if existing_config else None
if existing_value:
default_value = existing_value
else:
default_value = field.default if not isinstance(field.default, PydanticUndefinedType) else None
is_required = field.is_required
# Skip fields with Literal type
if get_origin(field_type) is Literal:
continue
# Skip fields with no type annotations
if is_basemodel_without_fields(field_type):
config_data[field_name] = field_type()
continue
if inspect.isclass(field_type) and issubclass(field_type, Enum):
prompt = f"Choose {field_name} (options: {', '.join(e.name for e in field_type)}):"
while True:
# this branch does not handle existing and default values yet
user_input = input(prompt + " ")
try:
value = field_type[user_input]
validated_value = manually_validate_field(config_type, field, value)
config_data[field_name] = validated_value
break
except KeyError:
log.error(f"Invalid choice. Please choose from: {', '.join(e.name for e in field_type)}")
continue
if is_discriminated_union(field):
config_data[field_name] = prompt_for_discriminated_union(field_name, field, existing_value)
continue
if is_optional(field_type) and can_recurse(get_non_none_type(field_type)):
prompt = f"Do you want to configure {field_name}? (y/n): "
if input(prompt).lower() == "n":
config_data[field_name] = None
continue
nested_type = get_non_none_type(field_type)
log.info(f"Entering sub-configuration for {field_name}:")
config_data[field_name] = prompt_for_config(nested_type, existing_value)
elif is_optional(field_type) and is_discriminated_union(get_non_none_type(field_type)):
prompt = f"Do you want to configure {field_name}? (y/n): "
if input(prompt).lower() == "n":
config_data[field_name] = None
continue
nested_type = get_non_none_type(field_type)
config_data[field_name] = prompt_for_discriminated_union(
field_name,
nested_type,
existing_value,
)
elif can_recurse(field_type):
log.info(f"\nEntering sub-configuration for {field_name}:")
config_data[field_name] = prompt_for_config(
field_type,
existing_value,
)
else:
prompt = f"Enter value for {field_name}"
if existing_value is not None:
prompt += f" (existing: {existing_value})"
elif default_value is not None:
prompt += f" (default: {default_value})"
if is_optional(field_type):
prompt += " (optional)"
elif is_required:
prompt += " (required)"
prompt += ": "
while True:
user_input = input(prompt)
if user_input == "":
if default_value is not None:
config_data[field_name] = default_value
break
elif is_optional(field_type) or not is_required:
config_data[field_name] = None
break
else:
log.error("This field is required. Please provide a value.")
continue
else:
try:
# Handle Optional types
if is_optional(field_type):
if user_input.lower() == "none":
value = None
else:
field_type = get_non_none_type(field_type)
value = user_input
# Handle List of primitives
elif is_list_of_primitives(field_type):
try:
value = json.loads(user_input)
if not isinstance(value, list):
raise ValueError("Input must be a JSON-encoded list")
element_type = get_args(field_type)[0]
value = [element_type(item) for item in value]
except json.JSONDecodeError:
log.error('Invalid JSON. Please enter a valid JSON-encoded list e.g., ["foo","bar"]')
continue
except ValueError as e:
log.error(f"{str(e)}")
continue
elif get_origin(field_type) is dict:
try:
value = json.loads(user_input)
if not isinstance(value, dict):
raise ValueError("Input must be a JSON-encoded dictionary")
except json.JSONDecodeError:
log.error("Invalid JSON. Please enter a valid JSON-encoded dict.")
continue
# Convert the input to the correct type
elif inspect.isclass(field_type) and issubclass(field_type, BaseModel):
# For nested BaseModels, we assume a dictionary-like string input
import ast
value = field_type(**ast.literal_eval(user_input))
else:
value = field_type(user_input)
except ValueError:
log.error(f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}")
continue
try:
# Validate the field using our manual validation function
validated_value = manually_validate_field(config_type, field_name, value)
config_data[field_name] = validated_value
break
except ValueError as e:
log.error(f"Validation error: {str(e)}")
return config_type(**config_data)

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from datetime import datetime
from enum import Enum
class EnumEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, Enum):
return obj.value
elif isinstance(obj, datetime):
return obj.isoformat()
return super().default(obj)