chore(package): migrate to src/ layout (#3920)

Migrates package structure to src/ layout following Python packaging
best practices.

All code moved from `llama_stack/` to `src/llama_stack/`. Public API
unchanged - imports remain `import llama_stack.*`.

Updated build configs, pre-commit hooks, scripts, and GitHub workflows
accordingly. All hooks pass, package builds cleanly.

**Developer note**: Reinstall after pulling: `pip install -e .`
This commit is contained in:
Ashwin Bharambe 2025-10-27 12:02:21 -07:00 committed by GitHub
parent 98a5047f9d
commit 471b1b248b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
791 changed files with 2983 additions and 456 deletions

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,131 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.core.datatypes import User
from .conditions import (
Condition,
ProtectedResource,
parse_conditions,
)
from .datatypes import (
AccessRule,
Action,
Scope,
)
def matches_resource(resource_scope: str, actual_resource: str) -> bool:
if resource_scope == actual_resource:
return True
return resource_scope.endswith("::*") and actual_resource.startswith(resource_scope[:-1])
def matches_scope(
scope: Scope,
action: Action,
resource: str,
user: str | None,
) -> bool:
if scope.resource and not matches_resource(scope.resource, resource):
return False
if scope.principal and scope.principal != user:
return False
return action in scope.actions
def as_list(obj: Any) -> list[Any]:
if isinstance(obj, list):
return obj
return [obj]
def matches_conditions(
conditions: list[Condition],
resource: ProtectedResource,
user: User,
) -> bool:
for condition in conditions:
# must match all conditions
if not condition.matches(resource, user):
return False
return True
def default_policy() -> list[AccessRule]:
# for backwards compatibility, if no rules are provided, assume
# full access subject to previous attribute matching rules
return [
AccessRule(
permit=Scope(actions=list(Action)),
when=["user in owners " + name for name in ["roles", "teams", "projects", "namespaces"]],
),
]
def is_action_allowed(
policy: list[AccessRule],
action: Action,
resource: ProtectedResource,
user: User | None,
) -> bool:
# If user is not set, assume authentication is not enabled
if not user:
return True
if not len(policy):
policy = default_policy()
qualified_resource_id = f"{resource.type}::{resource.identifier}"
for rule in policy:
if rule.forbid and matches_scope(rule.forbid, action, qualified_resource_id, user.principal):
if rule.when:
if matches_conditions(parse_conditions(as_list(rule.when)), resource, user):
return False
elif rule.unless:
if not matches_conditions(parse_conditions(as_list(rule.unless)), resource, user):
return False
else:
return False
elif rule.permit and matches_scope(rule.permit, action, qualified_resource_id, user.principal):
if rule.when:
if matches_conditions(parse_conditions(as_list(rule.when)), resource, user):
return True
elif rule.unless:
if not matches_conditions(parse_conditions(as_list(rule.unless)), resource, user):
return True
else:
return True
# assume access is denied unless we find a rule that permits access
return False
class AccessDeniedError(RuntimeError):
def __init__(self, action: str | None = None, resource: ProtectedResource | None = None, user: User | None = None):
self.action = action
self.resource = resource
self.user = user
message = _build_access_denied_message(action, resource, user)
super().__init__(message)
def _build_access_denied_message(action: str | None, resource: ProtectedResource | None, user: User | None) -> str:
"""Build detailed error message for access denied scenarios."""
if action and resource and user:
resource_info = f"{resource.type}::{resource.identifier}"
user_info = f"'{user.principal}'"
if user.attributes:
attrs = ", ".join([f"{k}={v}" for k, v in user.attributes.items()])
user_info += f" (attributes: {attrs})"
message = f"User {user_info} cannot perform action '{action}' on resource '{resource_info}'"
else:
message = "Insufficient permissions"
return message

View file

@ -0,0 +1,129 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Protocol
class User(Protocol):
principal: str
attributes: dict[str, list[str]] | None
class ProtectedResource(Protocol):
type: str
identifier: str
owner: User
class Condition(Protocol):
def matches(self, resource: ProtectedResource, user: User) -> bool: ...
class UserInOwnersList:
def __init__(self, name: str):
self.name = name
def owners_values(self, resource: ProtectedResource) -> list[str] | None:
if (
hasattr(resource, "owner")
and resource.owner
and resource.owner.attributes
and self.name in resource.owner.attributes
):
return resource.owner.attributes[self.name]
else:
return None
def matches(self, resource: ProtectedResource, user: User) -> bool:
required = self.owners_values(resource)
if not required:
return True
if not user.attributes or self.name not in user.attributes or not user.attributes[self.name]:
return False
user_values = user.attributes[self.name]
for value in required:
if value in user_values:
return True
return False
def __repr__(self):
return f"user in owners {self.name}"
class UserNotInOwnersList(UserInOwnersList):
def __init__(self, name: str):
super().__init__(name)
def matches(self, resource: ProtectedResource, user: User) -> bool:
return not super().matches(resource, user)
def __repr__(self):
return f"user not in owners {self.name}"
class UserWithValueInList:
def __init__(self, name: str, value: str):
self.name = name
self.value = value
def matches(self, resource: ProtectedResource, user: User) -> bool:
if user.attributes and self.name in user.attributes:
return self.value in user.attributes[self.name]
print(f"User does not have {self.value} in {self.name}")
return False
def __repr__(self):
return f"user with {self.value} in {self.name}"
class UserWithValueNotInList(UserWithValueInList):
def __init__(self, name: str, value: str):
super().__init__(name, value)
def matches(self, resource: ProtectedResource, user: User) -> bool:
return not super().matches(resource, user)
def __repr__(self):
return f"user with {self.value} not in {self.name}"
class UserIsOwner:
def matches(self, resource: ProtectedResource, user: User) -> bool:
return resource.owner.principal == user.principal if resource.owner else False
def __repr__(self):
return "user is owner"
class UserIsNotOwner:
def matches(self, resource: ProtectedResource, user: User) -> bool:
return not resource.owner or resource.owner.principal != user.principal
def __repr__(self):
return "user is not owner"
def parse_condition(condition: str) -> Condition:
words = condition.split()
match words:
case ["user", "is", "owner"]:
return UserIsOwner()
case ["user", "is", "not", "owner"]:
return UserIsNotOwner()
case ["user", "with", value, "in", name]:
return UserWithValueInList(name, value)
case ["user", "with", value, "not", "in", name]:
return UserWithValueNotInList(name, value)
case ["user", "in", "owners", name]:
return UserInOwnersList(name)
case ["user", "not", "in", "owners", name]:
return UserNotInOwnersList(name)
case _:
raise ValueError(f"Invalid condition: {condition}")
def parse_conditions(conditions: list[str]) -> list[Condition]:
return [parse_condition(c) for c in conditions]

View file

@ -0,0 +1,107 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import StrEnum
from typing import Self
from pydantic import BaseModel, model_validator
from .conditions import parse_conditions
class Action(StrEnum):
CREATE = "create"
READ = "read"
UPDATE = "update"
DELETE = "delete"
class Scope(BaseModel):
principal: str | None = None
actions: Action | list[Action]
resource: str | None = None
def _mutually_exclusive(obj, a: str, b: str):
if getattr(obj, a) and getattr(obj, b):
raise ValueError(f"{a} and {b} are mutually exclusive")
def _require_one_of(obj, a: str, b: str):
if not getattr(obj, a) and not getattr(obj, b):
raise ValueError(f"on of {a} or {b} is required")
class AccessRule(BaseModel):
"""Access rule based loosely on cedar policy language
A rule defines a list of action either to permit or to forbid. It may specify a
principal or a resource that must match for the rule to take effect. The resource
to match should be specified in the form of a type qualified identifier, e.g.
model::my-model or vector_store::some-db, or a wildcard for all resources of a type,
e.g. model::*. If the principal or resource are not specified, they will match all
requests.
A rule may also specify a condition, either a 'when' or an 'unless', with additional
constraints as to where the rule applies. The constraints supported at present are:
- 'user with <attr-value> in <attr-name>'
- 'user with <attr-value> not in <attr-name>'
- 'user is owner'
- 'user is not owner'
- 'user in owners <attr-name>'
- 'user not in owners <attr-name>'
Rules are tested in order to find a match. If a match is found, the request is
permitted or forbidden depending on the type of rule. If no match is found, the
request is denied. If no rules are specified, a rule that allows any action as
long as the resource attributes match the user attributes is added
(i.e. the previous behaviour is the default).
Some examples in yaml:
- permit:
principal: user-1
actions: [create, read, delete]
resource: model::*
description: user-1 has full access to all models
- permit:
principal: user-2
actions: [read]
resource: model::model-1
description: user-2 has read access to model-1 only
- permit:
actions: [read]
when: user in owner teams
description: any user has read access to any resource created by a member of their team
- forbid:
actions: [create, read, delete]
resource: vector_store::*
unless: user with admin in roles
description: only user with admin role can use vector_store resources
"""
permit: Scope | None = None
forbid: Scope | None = None
when: str | list[str] | None = None
unless: str | list[str] | None = None
description: str | None = None
@model_validator(mode="after")
def validate_rule_format(self) -> Self:
_require_one_of(self, "permit", "forbid")
_mutually_exclusive(self, "permit", "forbid")
_mutually_exclusive(self, "when", "unless")
if isinstance(self.when, list):
parse_conditions(self.when)
elif self.when:
parse_conditions([self.when])
if isinstance(self.unless, list):
parse_conditions(self.unless)
elif self.unless:
parse_conditions([self.unless])
return self

View file

@ -0,0 +1,164 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import importlib.resources
import sys
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.core.datatypes import BuildConfig
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.external import load_external_apis
from llama_stack.core.utils.exec import run_command
from llama_stack.core.utils.image_types import LlamaStackImageType
from llama_stack.distributions.template import DistributionTemplate
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
log = get_logger(name=__name__, category="core")
# These are the dependencies needed by the distribution server.
# `llama-stack` is automatically installed by the installation script.
SERVER_DEPENDENCIES = [
"aiosqlite",
"fastapi",
"fire",
"httpx",
"uvicorn",
"opentelemetry-sdk",
"opentelemetry-exporter-otlp-proto-http",
]
class ApiInput(BaseModel):
api: Api
provider: str
def get_provider_dependencies(
config: BuildConfig | DistributionTemplate,
) -> tuple[list[str], list[str], list[str]]:
"""Get normal and special dependencies from provider configuration."""
if isinstance(config, DistributionTemplate):
config = config.build_config()
providers = config.distribution_spec.providers
additional_pip_packages = config.additional_pip_packages
deps = []
external_provider_deps = []
registry = get_provider_registry(config)
for api_str, provider_or_providers in providers.items():
providers_for_api = registry[Api(api_str)]
providers = provider_or_providers if isinstance(provider_or_providers, list) else [provider_or_providers]
for provider in providers:
# Providers from BuildConfig and RunConfig are subtly different - not great
provider_type = provider if isinstance(provider, str) else provider.provider_type
if provider_type not in providers_for_api:
raise ValueError(f"Provider `{provider}` is not available for API `{api_str}`")
provider_spec = providers_for_api[provider_type]
if hasattr(provider_spec, "is_external") and provider_spec.is_external:
# this ensures we install the top level module for our external providers
if provider_spec.module:
if isinstance(provider_spec.module, str):
external_provider_deps.append(provider_spec.module)
else:
external_provider_deps.extend(provider_spec.module)
if hasattr(provider_spec, "pip_packages"):
deps.extend(provider_spec.pip_packages)
if hasattr(provider_spec, "container_image") and provider_spec.container_image:
raise ValueError("A stack's dependencies cannot have a container image")
normal_deps = []
special_deps = []
for package in deps:
if any(f in package for f in ["--no-deps", "--index-url", "--extra-index-url"]):
special_deps.append(package)
else:
normal_deps.append(package)
normal_deps.extend(additional_pip_packages or [])
return list(set(normal_deps)), list(set(special_deps)), list(set(external_provider_deps))
def print_pip_install_help(config: BuildConfig):
normal_deps, special_deps, _ = get_provider_dependencies(config)
cprint(
f"Please install needed dependencies using the following commands:\n\nuv pip install {' '.join(normal_deps)}",
color="yellow",
file=sys.stderr,
)
for special_dep in special_deps:
cprint(f"uv pip install {special_dep}", color="yellow", file=sys.stderr)
print()
def build_image(
build_config: BuildConfig,
image_name: str,
distro_or_config: str,
run_config: str | None = None,
):
container_base = build_config.distribution_spec.container_image or "python:3.12-slim"
normal_deps, special_deps, external_provider_deps = get_provider_dependencies(build_config)
normal_deps += SERVER_DEPENDENCIES
if build_config.external_apis_dir:
external_apis = load_external_apis(build_config)
if external_apis:
for _, api_spec in external_apis.items():
normal_deps.extend(api_spec.pip_packages)
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
script = str(importlib.resources.files("llama_stack") / "core/build_container.sh")
args = [
script,
"--distro-or-config",
distro_or_config,
"--image-name",
image_name,
"--container-base",
container_base,
"--normal-deps",
" ".join(normal_deps),
]
# When building from a config file (not a template), include the run config path in the
# build arguments
if run_config is not None:
args.extend(["--run-config", run_config])
else:
script = str(importlib.resources.files("llama_stack") / "core/build_venv.sh")
args = [
script,
"--env-name",
str(image_name),
"--normal-deps",
" ".join(normal_deps),
]
# Always pass both arguments, even if empty, to maintain consistent positional arguments
if special_deps:
args.extend(["--optional-deps", "#".join(special_deps)])
if external_provider_deps:
args.extend(
["--external-provider-deps", "#".join(external_provider_deps)]
) # the script will install external provider module, get its deps, and install those too.
return_code = run_command(args)
if return_code != 0:
log.error(
f"Failed to build target {image_name} with return code {return_code}",
)
return return_code

View file

@ -0,0 +1,205 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import inspect
import json
import sys
from collections.abc import AsyncIterator
from enum import Enum
from typing import Any, Union, get_args, get_origin
import httpx
from pydantic import BaseModel, parse_obj_as
from termcolor import cprint
from llama_stack.providers.datatypes import RemoteProviderConfig
_CLIENT_CLASSES = {}
async def get_client_impl(protocol, config: RemoteProviderConfig, _deps: Any):
client_class = create_api_client_class(protocol)
impl = client_class(config.url)
await impl.initialize()
return impl
def create_api_client_class(protocol) -> type:
if protocol in _CLIENT_CLASSES:
return _CLIENT_CLASSES[protocol]
class APIClient:
def __init__(self, base_url: str):
print(f"({protocol.__name__}) Connecting to {base_url}")
self.base_url = base_url.rstrip("/")
self.routes = {}
# Store routes for this protocol
for name, method in inspect.getmembers(protocol):
if hasattr(method, "__webmethod__"):
sig = inspect.signature(method)
self.routes[name] = (method.__webmethod__, sig)
async def initialize(self):
pass
async def shutdown(self):
pass
async def __acall__(self, method_name: str, *args, **kwargs) -> Any:
assert method_name in self.routes, f"Unknown endpoint: {method_name}"
# TODO: make this more precise, same thing needs to happen in server.py
is_streaming = kwargs.get("stream", False)
if is_streaming:
return self._call_streaming(method_name, *args, **kwargs)
else:
return await self._call_non_streaming(method_name, *args, **kwargs)
async def _call_non_streaming(self, method_name: str, *args, **kwargs) -> Any:
_, sig = self.routes[method_name]
if sig.return_annotation is None:
return_type = None
else:
return_type = extract_non_async_iterator_type(sig.return_annotation)
assert return_type, f"Could not extract return type for {sig.return_annotation}"
async with httpx.AsyncClient() as client:
params = self.httpx_request_params(method_name, *args, **kwargs)
response = await client.request(**params)
response.raise_for_status()
j = response.json()
if j is None:
return None
# print(f"({protocol.__name__}) Returning {j}, type {return_type}")
return parse_obj_as(return_type, j)
async def _call_streaming(self, method_name: str, *args, **kwargs) -> Any:
webmethod, sig = self.routes[method_name]
return_type = extract_async_iterator_type(sig.return_annotation)
assert return_type, f"Could not extract return type for {sig.return_annotation}"
async with httpx.AsyncClient() as client:
params = self.httpx_request_params(method_name, *args, **kwargs)
async with client.stream(**params) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if line.startswith("data:"):
data = line[len("data: ") :]
try:
data = json.loads(data)
if "error" in data:
cprint(data, color="red", file=sys.stderr)
continue
yield parse_obj_as(return_type, data)
except Exception as e:
cprint(f"Error with parsing or validation: {e}", color="red", file=sys.stderr)
cprint(data, color="red", file=sys.stderr)
def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict:
webmethod, sig = self.routes[method_name]
parameters = list(sig.parameters.values())[1:] # skip `self`
for i, param in enumerate(parameters):
if i >= len(args):
break
kwargs[param.name] = args[i]
# Get all webmethods for this method (supports multiple decorators)
webmethods = getattr(method, "__webmethods__", [])
if not webmethods:
raise RuntimeError(f"Method {method} has no webmethod decorators")
# Choose the preferred webmethod (non-deprecated if available)
preferred_webmethod = None
for wm in webmethods:
if not getattr(wm, "deprecated", False):
preferred_webmethod = wm
break
# If no non-deprecated found, use the first one
if preferred_webmethod is None:
preferred_webmethod = webmethods[0]
url = f"{self.base_url}/{preferred_webmethod.level}/{preferred_webmethod.route.lstrip('/')}"
def convert(value):
if isinstance(value, list):
return [convert(v) for v in value]
elif isinstance(value, dict):
return {k: convert(v) for k, v in value.items()}
elif isinstance(value, BaseModel):
return json.loads(value.model_dump_json())
elif isinstance(value, Enum):
return value.value
else:
return value
params = {}
data = {}
if webmethod.method == "GET":
params.update(kwargs)
else:
data.update(convert(kwargs))
ret = dict(
method=webmethod.method or "POST",
url=url,
headers={
"Accept": "application/json",
"Content-Type": "application/json",
},
timeout=30,
)
if params:
ret["params"] = params
if data:
ret["json"] = data
return ret
# Add protocol methods to the wrapper
for name, method in inspect.getmembers(protocol):
if hasattr(method, "__webmethod__"):
async def method_impl(self, *args, method_name=name, **kwargs):
return await self.__acall__(method_name, *args, **kwargs)
method_impl.__name__ = name
method_impl.__qualname__ = f"APIClient.{name}"
method_impl.__signature__ = inspect.signature(method)
setattr(APIClient, name, method_impl)
# Name the class after the protocol
APIClient.__name__ = f"{protocol.__name__}Client"
_CLIENT_CLASSES[protocol] = APIClient
return APIClient
# not quite general these methods are
def extract_non_async_iterator_type(type_hint):
if get_origin(type_hint) is Union:
args = get_args(type_hint)
for arg in args:
if not issubclass(get_origin(arg) or arg, AsyncIterator):
return arg
return type_hint
def extract_async_iterator_type(type_hint):
if get_origin(type_hint) is Union:
args = get_args(type_hint)
for arg in args:
if issubclass(get_origin(arg) or arg, AsyncIterator):
inner_args = get_args(arg)
return inner_args[0]
return None

37
src/llama_stack/core/common.sh Executable file
View file

@ -0,0 +1,37 @@
#!/usr/bin/env bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
cleanup() {
# For venv environments, no special cleanup is needed
# This function exists to avoid "function not found" errors
local env_name="$1"
echo "Cleanup called for environment: $env_name"
}
handle_int() {
if [ -n "$ENVNAME" ]; then
cleanup "$ENVNAME"
fi
exit 1
}
handle_exit() {
if [ $? -ne 0 ]; then
echo -e "\033[1;31mABORTING.\033[0m"
if [ -n "$ENVNAME" ]; then
cleanup "$ENVNAME"
fi
fi
}
# check if a command is present
is_command_available() {
command -v "$1" &>/dev/null
}

View file

@ -0,0 +1,212 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import textwrap
from typing import Any
from llama_stack.core.datatypes import (
LLAMA_STACK_RUN_CONFIG_VERSION,
DistributionSpec,
Provider,
StackRunConfig,
)
from llama_stack.core.distribution import (
builtin_automatically_routed_apis,
get_provider_registry,
)
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.core.utils.prompt_for_config import prompt_for_config
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, ProviderSpec
logger = get_logger(name=__name__, category="core")
def configure_single_provider(registry: dict[str, ProviderSpec], provider: Provider) -> Provider:
provider_spec = registry[provider.provider_type]
config_type = instantiate_class_type(provider_spec.config_class)
try:
if provider.config:
existing = config_type(**provider.config)
else:
existing = None
except Exception:
existing = None
cfg = prompt_for_config(config_type, existing)
return Provider(
provider_id=provider.provider_id,
provider_type=provider.provider_type,
config=cfg.model_dump(),
)
def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec) -> StackRunConfig:
is_nux = len(config.providers) == 0
if is_nux:
logger.info(
textwrap.dedent(
"""
Llama Stack is composed of several APIs working together. For each API served by the Stack,
we need to configure the providers (implementations) you want to use for these APIs.
"""
)
)
provider_registry = get_provider_registry()
builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()]
if config.apis:
apis_to_serve = config.apis
else:
apis_to_serve = [a.value for a in Api if a not in (Api.inspect, Api.providers)]
for api_str in apis_to_serve:
api = Api(api_str)
if api in builtin_apis:
continue
if api not in provider_registry:
raise ValueError(f"Unknown API `{api_str}`")
existing_providers = config.providers.get(api_str, [])
if existing_providers:
logger.info(f"Re-configuring existing providers for API `{api_str}`...")
updated_providers = []
for p in existing_providers:
logger.info(f"> Configuring provider `({p.provider_type})`")
updated_providers.append(configure_single_provider(provider_registry[api], p))
logger.info("")
else:
# we are newly configuring this API
plist = build_spec.providers.get(api_str, [])
plist = plist if isinstance(plist, list) else [plist]
if not plist:
raise ValueError(f"No provider configured for API {api_str}?")
logger.info(f"Configuring API `{api_str}`...")
updated_providers = []
for i, provider in enumerate(plist):
if i >= 1:
others = ", ".join(p.provider_type for p in plist[i:])
logger.info(
f"Not configuring other providers ({others}) interactively. Please edit the resulting YAML directly.\n"
)
break
logger.info(f"> Configuring provider `({provider.provider_type})`")
pid = provider.provider_type.split("::")[-1]
updated_providers.append(
configure_single_provider(
provider_registry[api],
Provider(
provider_id=(f"{pid}-{i:02d}" if len(plist) > 1 else pid),
provider_type=provider.provider_type,
config={},
),
)
)
logger.info("")
config.providers[api_str] = updated_providers
return config
def upgrade_from_routing_table(
config_dict: dict[str, Any],
) -> dict[str, Any]:
def get_providers(entries):
return [
Provider(
provider_id=(f"{entry['provider_type']}-{i:02d}" if len(entries) > 1 else entry["provider_type"]),
provider_type=entry["provider_type"],
config=entry["config"],
)
for i, entry in enumerate(entries)
]
providers_by_api = {}
routing_table = config_dict.get("routing_table", {})
for api_str, entries in routing_table.items():
providers = get_providers(entries)
providers_by_api[api_str] = providers
provider_map = config_dict.get("api_providers", config_dict.get("provider_map", {}))
if provider_map:
for api_str, provider in provider_map.items():
if isinstance(provider, dict) and "provider_type" in provider:
providers_by_api[api_str] = [
Provider(
provider_id=f"{provider['provider_type']}",
provider_type=provider["provider_type"],
config=provider["config"],
)
]
config_dict["providers"] = providers_by_api
config_dict.pop("routing_table", None)
config_dict.pop("api_providers", None)
config_dict.pop("provider_map", None)
config_dict["apis"] = config_dict["apis_to_serve"]
config_dict.pop("apis_to_serve", None)
# Add default storage config if not present
if "storage" not in config_dict:
config_dict["storage"] = {
"backends": {
"kv_default": {
"type": "kv_sqlite",
"db_path": "~/.llama/kvstore.db",
},
"sql_default": {
"type": "sql_sqlite",
"db_path": "~/.llama/sql_store.db",
},
},
"stores": {
"metadata": {
"namespace": "registry",
"backend": "kv_default",
},
"inference": {
"table_name": "inference_store",
"backend": "sql_default",
"max_write_queue_size": 10000,
"num_writers": 4,
},
"conversations": {
"table_name": "openai_conversations",
"backend": "sql_default",
},
},
}
return config_dict
def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfig:
version = config_dict.get("version", None)
if version == LLAMA_STACK_RUN_CONFIG_VERSION:
processed_config_dict = replace_env_vars(config_dict)
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))
if "routing_table" in config_dict:
logger.info("Upgrading config...")
config_dict = upgrade_from_routing_table(config_dict)
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
if not config_dict.get("external_providers_dir", None):
config_dict["external_providers_dir"] = EXTERNAL_PROVIDERS_DIR
processed_config_dict = replace_env_vars(config_dict)
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,314 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import secrets
import time
from typing import Any, Literal
from pydantic import BaseModel, TypeAdapter
from llama_stack.apis.conversations.conversations import (
Conversation,
ConversationDeletedResource,
ConversationItem,
ConversationItemDeletedResource,
ConversationItemInclude,
ConversationItemList,
Conversations,
Metadata,
)
from llama_stack.core.datatypes import AccessRule, StackRunConfig
from llama_stack.log import get_logger
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
logger = get_logger(name=__name__, category="openai_conversations")
class ConversationServiceConfig(BaseModel):
"""Configuration for the built-in conversation service.
:param run_config: Stack run configuration for resolving persistence
:param policy: Access control rules
"""
run_config: StackRunConfig
policy: list[AccessRule] = []
async def get_provider_impl(config: ConversationServiceConfig, deps: dict[Any, Any]):
"""Get the conversation service implementation."""
impl = ConversationServiceImpl(config, deps)
await impl.initialize()
return impl
class ConversationServiceImpl(Conversations):
"""Built-in conversation service implementation using AuthorizedSqlStore."""
def __init__(self, config: ConversationServiceConfig, deps: dict[Any, Any]):
self.config = config
self.deps = deps
self.policy = config.policy
# Use conversations store reference from run config
conversations_ref = config.run_config.storage.stores.conversations
if not conversations_ref:
raise ValueError("storage.stores.conversations must be configured in run config")
base_sql_store = sqlstore_impl(conversations_ref)
self.sql_store = AuthorizedSqlStore(base_sql_store, self.policy)
async def initialize(self) -> None:
"""Initialize the store and create tables."""
await self.sql_store.create_table(
"openai_conversations",
{
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
"created_at": ColumnType.INTEGER,
"items": ColumnType.JSON,
"metadata": ColumnType.JSON,
},
)
await self.sql_store.create_table(
"conversation_items",
{
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
"conversation_id": ColumnType.STRING,
"created_at": ColumnType.INTEGER,
"item_data": ColumnType.JSON,
},
)
async def create_conversation(
self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
) -> Conversation:
"""Create a conversation."""
random_bytes = secrets.token_bytes(24)
conversation_id = f"conv_{random_bytes.hex()}"
created_at = int(time.time())
record_data = {
"id": conversation_id,
"created_at": created_at,
"items": [],
"metadata": metadata,
}
await self.sql_store.insert(
table="openai_conversations",
data=record_data,
)
if items:
item_records = []
for item in items:
item_dict = item.model_dump()
item_id = self._get_or_generate_item_id(item, item_dict)
item_record = {
"id": item_id,
"conversation_id": conversation_id,
"created_at": created_at,
"item_data": item_dict,
}
item_records.append(item_record)
await self.sql_store.insert(table="conversation_items", data=item_records)
conversation = Conversation(
id=conversation_id,
created_at=created_at,
metadata=metadata,
object="conversation",
)
logger.debug(f"Created conversation {conversation_id}")
return conversation
async def get_conversation(self, conversation_id: str) -> Conversation:
"""Get a conversation with the given ID."""
record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
if record is None:
raise ValueError(f"Conversation {conversation_id} not found")
return Conversation(
id=record["id"], created_at=record["created_at"], metadata=record.get("metadata"), object="conversation"
)
async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
"""Update a conversation's metadata with the given ID"""
await self.sql_store.update(
table="openai_conversations", data={"metadata": metadata}, where={"id": conversation_id}
)
return await self.get_conversation(conversation_id)
async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
"""Delete a conversation with the given ID."""
await self.sql_store.delete(table="openai_conversations", where={"id": conversation_id})
logger.debug(f"Deleted conversation {conversation_id}")
return ConversationDeletedResource(id=conversation_id)
def _validate_conversation_id(self, conversation_id: str) -> None:
"""Validate conversation ID format."""
if not conversation_id.startswith("conv_"):
raise ValueError(
f"Invalid 'conversation_id': '{conversation_id}'. Expected an ID that begins with 'conv_'."
)
def _get_or_generate_item_id(self, item: ConversationItem, item_dict: dict) -> str:
"""Get existing item ID or generate one if missing."""
if item.id is None:
random_bytes = secrets.token_bytes(24)
if item.type == "message":
item_id = f"msg_{random_bytes.hex()}"
else:
item_id = f"item_{random_bytes.hex()}"
item_dict["id"] = item_id
return item_id
return item.id
async def _get_validated_conversation(self, conversation_id: str) -> Conversation:
"""Validate conversation ID and return the conversation if it exists."""
self._validate_conversation_id(conversation_id)
return await self.get_conversation(conversation_id)
async def add_items(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
"""Create (add) items to a conversation."""
await self._get_validated_conversation(conversation_id)
created_items = []
base_time = int(time.time())
for i, item in enumerate(items):
item_dict = item.model_dump()
item_id = self._get_or_generate_item_id(item, item_dict)
# make each timestamp unique to maintain order
created_at = base_time + i
item_record = {
"id": item_id,
"conversation_id": conversation_id,
"created_at": created_at,
"item_data": item_dict,
}
# TODO: Add support for upsert in sql_store, this will fail first if ID exists and then update
try:
await self.sql_store.insert(table="conversation_items", data=item_record)
except Exception:
# If insert fails due to ID conflict, update existing record
await self.sql_store.update(
table="conversation_items",
data={"created_at": created_at, "item_data": item_dict},
where={"id": item_id},
)
created_items.append(item_dict)
logger.debug(f"Created {len(created_items)} items in conversation {conversation_id}")
# Convert created items (dicts) to proper ConversationItem types
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
response_items: list[ConversationItem] = [adapter.validate_python(item_dict) for item_dict in created_items]
return ConversationItemList(
data=response_items,
first_id=created_items[0]["id"] if created_items else None,
last_id=created_items[-1]["id"] if created_items else None,
has_more=False,
)
async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
"""Retrieve a conversation item."""
if not conversation_id:
raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
if not item_id:
raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")
# Get item from conversation_items table
record = await self.sql_store.fetch_one(
table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
)
if record is None:
raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
return adapter.validate_python(record["item_data"])
async def list_items(
self,
conversation_id: str,
after: str | None = None,
include: list[ConversationItemInclude] | None = None,
limit: int | None = None,
order: Literal["asc", "desc"] | None = None,
) -> ConversationItemList:
"""List items in the conversation."""
if not conversation_id:
raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
# check if conversation exists
await self.get_conversation(conversation_id)
result = await self.sql_store.fetch_all(table="conversation_items", where={"conversation_id": conversation_id})
records = result.data
if order is not None and order == "asc":
records.sort(key=lambda x: x["created_at"])
else:
records.sort(key=lambda x: x["created_at"], reverse=True)
actual_limit = limit or 20
records = records[:actual_limit]
items = [record["item_data"] for record in records]
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
response_items: list[ConversationItem] = [adapter.validate_python(item) for item in items]
first_id = response_items[0].id if response_items else None
last_id = response_items[-1].id if response_items else None
return ConversationItemList(
data=response_items,
first_id=first_id,
last_id=last_id,
has_more=False,
)
async def openai_delete_conversation_item(
self, conversation_id: str, item_id: str
) -> ConversationItemDeletedResource:
"""Delete a conversation item."""
if not conversation_id:
raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
if not item_id:
raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")
_ = await self._get_validated_conversation(conversation_id)
record = await self.sql_store.fetch_one(
table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
)
if record is None:
raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")
await self.sql_store.delete(
table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
)
logger.debug(f"Deleted item {item_id} from conversation {conversation_id}")
return ConversationItemDeletedResource(id=item_id)

View file

@ -0,0 +1,629 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import StrEnum
from pathlib import Path
from typing import Annotated, Any, Literal, Self
from urllib.parse import urlparse
from pydantic import BaseModel, Field, field_validator, model_validator
from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Dataset, DatasetInput
from llama_stack.apis.eval import Eval
from llama_stack.apis.inference import Inference
from llama_stack.apis.models import Model, ModelInput
from llama_stack.apis.resource import Resource
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
from llama_stack.apis.shields import Shield, ShieldInput
from llama_stack.apis.tools import ToolGroup, ToolGroupInput, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.apis.vector_stores import VectorStore, VectorStoreInput
from llama_stack.core.access_control.datatypes import AccessRule
from llama_stack.core.storage.datatypes import (
KVStoreReference,
StorageBackendType,
StorageConfig,
)
from llama_stack.log import LoggingConfig
from llama_stack.providers.datatypes import Api, ProviderSpec
LLAMA_STACK_BUILD_CONFIG_VERSION = 2
LLAMA_STACK_RUN_CONFIG_VERSION = 2
RoutingKey = str | list[str]
class RegistryEntrySource(StrEnum):
via_register_api = "via_register_api"
listed_from_provider = "listed_from_provider"
class User(BaseModel):
principal: str
# further attributes that may be used for access control decisions
attributes: dict[str, list[str]] | None = None
def __init__(self, principal: str, attributes: dict[str, list[str]] | None):
super().__init__(principal=principal, attributes=attributes)
class ResourceWithOwner(Resource):
"""Extension of Resource that adds an optional owner, i.e. the user that created the
resource. This can be used to constrain access to the resource."""
owner: User | None = None
source: RegistryEntrySource = RegistryEntrySource.via_register_api
# Use the extended Resource for all routable objects
class ModelWithOwner(Model, ResourceWithOwner):
pass
class ShieldWithOwner(Shield, ResourceWithOwner):
pass
class VectorStoreWithOwner(VectorStore, ResourceWithOwner):
pass
class DatasetWithOwner(Dataset, ResourceWithOwner):
pass
class ScoringFnWithOwner(ScoringFn, ResourceWithOwner):
pass
class BenchmarkWithOwner(Benchmark, ResourceWithOwner):
pass
class ToolGroupWithOwner(ToolGroup, ResourceWithOwner):
pass
RoutableObject = Model | Shield | VectorStore | Dataset | ScoringFn | Benchmark | ToolGroup
RoutableObjectWithProvider = Annotated[
ModelWithOwner
| ShieldWithOwner
| VectorStoreWithOwner
| DatasetWithOwner
| ScoringFnWithOwner
| BenchmarkWithOwner
| ToolGroupWithOwner,
Field(discriminator="type"),
]
RoutedProtocol = Inference | Safety | VectorIO | DatasetIO | Scoring | Eval | ToolRuntime
# Example: /inference, /safety
class AutoRoutedProviderSpec(ProviderSpec):
provider_type: str = "router"
config_class: str = ""
container_image: str | None = None
routing_table_api: Api
module: str
provider_data_validator: str | None = Field(
default=None,
)
# Example: /models, /shields
class RoutingTableProviderSpec(ProviderSpec):
provider_type: str = "routing_table"
config_class: str = ""
container_image: str | None = None
router_api: Api
module: str
pip_packages: list[str] = Field(default_factory=list)
class Provider(BaseModel):
# provider_id of None means that the provider is not enabled - this happens
# when the provider is enabled via a conditional environment variable
provider_id: str | None
provider_type: str
config: dict[str, Any] = {}
module: str | None = Field(
default=None,
description="""
Fully-qualified name of the external provider module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the adapter implementation
Example: `module: ramalama_stack`
""",
)
class BuildProvider(BaseModel):
provider_type: str
module: str | None = Field(
default=None,
description="""
Fully-qualified name of the external provider module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the adapter implementation
Example: `module: ramalama_stack`
""",
)
class DistributionSpec(BaseModel):
description: str | None = Field(
default="",
description="Description of the distribution",
)
container_image: str | None = None
providers: dict[str, list[BuildProvider]] = Field(
default_factory=dict,
description="""
Provider Types for each of the APIs provided by this distribution. If you
select multiple providers, you should provide an appropriate 'routing_map'
in the runtime configuration to help route to the correct provider.
""",
)
class TelemetryConfig(BaseModel):
"""
Configuration for telemetry.
Llama Stack uses OpenTelemetry for telemetry. Please refer to https://opentelemetry.io/docs/languages/sdk-configuration/
for env variables to configure the OpenTelemetry SDK.
Example:
```bash
OTEL_SERVICE_NAME=llama-stack OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4318 uv run llama stack run starter
```
"""
enabled: bool = Field(default=False, description="enable or disable telemetry")
class OAuth2JWKSConfig(BaseModel):
# The JWKS URI for collecting public keys
uri: str
token: str | None = Field(default=None, description="token to authorise access to jwks")
key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates")
class OAuth2IntrospectionConfig(BaseModel):
url: str
client_id: str
client_secret: str
send_secret_in_body: bool = False
class AuthProviderType(StrEnum):
"""Supported authentication provider types."""
OAUTH2_TOKEN = "oauth2_token"
GITHUB_TOKEN = "github_token"
CUSTOM = "custom"
KUBERNETES = "kubernetes"
class OAuth2TokenAuthConfig(BaseModel):
"""Configuration for OAuth2 token authentication."""
type: Literal[AuthProviderType.OAUTH2_TOKEN] = AuthProviderType.OAUTH2_TOKEN
audience: str = Field(default="llama-stack")
verify_tls: bool = Field(default=True)
tls_cafile: Path | None = Field(default=None)
issuer: str | None = Field(default=None, description="The OIDC issuer URL.")
claims_mapping: dict[str, str] = Field(
default_factory=lambda: {
"sub": "roles",
"username": "roles",
"groups": "teams",
"team": "teams",
"project": "projects",
"tenant": "namespaces",
"namespace": "namespaces",
},
)
jwks: OAuth2JWKSConfig | None = Field(default=None, description="JWKS configuration")
introspection: OAuth2IntrospectionConfig | None = Field(
default=None, description="OAuth2 introspection configuration"
)
@classmethod
@field_validator("claims_mapping")
def validate_claims_mapping(cls, v):
for key, value in v.items():
if not value:
raise ValueError(f"claims_mapping value cannot be empty: {key}")
return v
@model_validator(mode="after")
def validate_mode(self) -> Self:
if not self.jwks and not self.introspection:
raise ValueError("One of jwks or introspection must be configured")
if self.jwks and self.introspection:
raise ValueError("At present only one of jwks or introspection should be configured")
return self
class CustomAuthConfig(BaseModel):
"""Configuration for custom authentication."""
type: Literal[AuthProviderType.CUSTOM] = AuthProviderType.CUSTOM
endpoint: str = Field(
...,
description="Custom authentication endpoint URL",
)
class GitHubTokenAuthConfig(BaseModel):
"""Configuration for GitHub token authentication."""
type: Literal[AuthProviderType.GITHUB_TOKEN] = AuthProviderType.GITHUB_TOKEN
github_api_base_url: str = Field(
default="https://api.github.com",
description="Base URL for GitHub API (use https://api.github.com for public GitHub)",
)
claims_mapping: dict[str, str] = Field(
default_factory=lambda: {
"login": "roles",
"organizations": "teams",
},
description="Mapping from GitHub user fields to access attributes",
)
class KubernetesAuthProviderConfig(BaseModel):
"""Configuration for Kubernetes authentication provider."""
type: Literal[AuthProviderType.KUBERNETES] = AuthProviderType.KUBERNETES
api_server_url: str = Field(
default="https://kubernetes.default.svc",
description="Kubernetes API server URL (e.g., https://api.cluster.domain:6443)",
)
verify_tls: bool = Field(default=True, description="Whether to verify TLS certificates")
tls_cafile: Path | None = Field(default=None, description="Path to CA certificate file for TLS verification")
claims_mapping: dict[str, str] = Field(
default_factory=lambda: {
"username": "roles",
"groups": "roles",
},
description="Mapping of Kubernetes user claims to access attributes",
)
@field_validator("api_server_url")
@classmethod
def validate_api_server_url(cls, v):
parsed = urlparse(v)
if not parsed.scheme or not parsed.netloc:
raise ValueError(f"api_server_url must be a valid URL with scheme and host: {v}")
if parsed.scheme not in ["http", "https"]:
raise ValueError(f"api_server_url scheme must be http or https: {v}")
return v
@field_validator("claims_mapping")
@classmethod
def validate_claims_mapping(cls, v):
for key, value in v.items():
if not value:
raise ValueError(f"claims_mapping value cannot be empty: {key}")
return v
AuthProviderConfig = Annotated[
OAuth2TokenAuthConfig | GitHubTokenAuthConfig | CustomAuthConfig | KubernetesAuthProviderConfig,
Field(discriminator="type"),
]
class AuthenticationConfig(BaseModel):
"""Top-level authentication configuration."""
provider_config: AuthProviderConfig = Field(
...,
description="Authentication provider configuration",
)
access_policy: list[AccessRule] = Field(
default=[],
description="Rules for determining access to resources",
)
class AuthenticationRequiredError(Exception):
pass
class QualifiedModel(BaseModel):
"""A qualified model identifier, consisting of a provider ID and a model ID."""
provider_id: str
model_id: str
class VectorStoresConfig(BaseModel):
"""Configuration for vector stores in the stack."""
default_provider_id: str | None = Field(
default=None,
description="ID of the vector_io provider to use as default when multiple providers are available and none is specified.",
)
default_embedding_model: QualifiedModel | None = Field(
default=None,
description="Default embedding model configuration for vector stores.",
)
class SafetyConfig(BaseModel):
"""Configuration for default moderations model."""
default_shield_id: str | None = Field(
default=None,
description="ID of the shield to use for when `model` is not specified in the `moderations` API request.",
)
class QuotaPeriod(StrEnum):
DAY = "day"
class QuotaConfig(BaseModel):
kvstore: KVStoreReference = Field(description="Config for KV store backend (SQLite only for now)")
anonymous_max_requests: int = Field(default=100, description="Max requests for unauthenticated clients per period")
authenticated_max_requests: int = Field(
default=1000, description="Max requests for authenticated clients per period"
)
period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
class CORSConfig(BaseModel):
allow_origins: list[str] = Field(default_factory=list)
allow_origin_regex: str | None = Field(default=None)
allow_methods: list[str] = Field(default=["OPTIONS"])
allow_headers: list[str] = Field(default_factory=list)
allow_credentials: bool = Field(default=False)
expose_headers: list[str] = Field(default_factory=list)
max_age: int = Field(default=600, ge=0)
@model_validator(mode="after")
def validate_credentials_config(self) -> Self:
if self.allow_credentials and (self.allow_origins == ["*"] or "*" in self.allow_origins):
raise ValueError("Cannot use wildcard origins with credentials enabled")
return self
def process_cors_config(cors_config: bool | CORSConfig | None) -> CORSConfig | None:
if cors_config is False or cors_config is None:
return None
if cors_config is True:
# dev mode: allow localhost on any port
return CORSConfig(
allow_origins=[],
allow_origin_regex=r"https?://localhost:\d+",
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
allow_headers=["Content-Type", "Authorization", "X-Requested-With"],
)
if isinstance(cors_config, CORSConfig):
return cors_config
raise ValueError(f"Expected bool or CORSConfig, got {type(cors_config).__name__}")
class RegisteredResources(BaseModel):
"""Registry of resources available in the distribution."""
models: list[ModelInput] = Field(default_factory=list)
shields: list[ShieldInput] = Field(default_factory=list)
vector_stores: list[VectorStoreInput] = Field(default_factory=list)
datasets: list[DatasetInput] = Field(default_factory=list)
scoring_fns: list[ScoringFnInput] = Field(default_factory=list)
benchmarks: list[BenchmarkInput] = Field(default_factory=list)
tool_groups: list[ToolGroupInput] = Field(default_factory=list)
class ServerConfig(BaseModel):
port: int = Field(
default=8321,
description="Port to listen on",
ge=1024,
le=65535,
)
tls_certfile: str | None = Field(
default=None,
description="Path to TLS certificate file for HTTPS",
)
tls_keyfile: str | None = Field(
default=None,
description="Path to TLS key file for HTTPS",
)
tls_cafile: str | None = Field(
default=None,
description="Path to TLS CA file for HTTPS with mutual TLS authentication",
)
auth: AuthenticationConfig | None = Field(
default=None,
description="Authentication configuration for the server",
)
host: str | None = Field(
default=None,
description="The host the server should listen on",
)
quota: QuotaConfig | None = Field(
default=None,
description="Per client quota request configuration",
)
cors: bool | CORSConfig | None = Field(
default=None,
description="CORS configuration for cross-origin requests. Can be:\n"
"- true: Enable localhost CORS for development\n"
"- {allow_origins: [...], allow_methods: [...], ...}: Full configuration",
)
class StackRunConfig(BaseModel):
version: int = LLAMA_STACK_RUN_CONFIG_VERSION
image_name: str = Field(
...,
description="""
Reference to the distribution this package refers to. For unregistered (adhoc) packages,
this could be just a hash
""",
)
container_image: str | None = Field(
default=None,
description="Reference to the container image if this package refers to a container",
)
apis: list[str] = Field(
default_factory=list,
description="""
The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
)
providers: dict[str, list[Provider]] = Field(
description="""
One or more providers to use for each API. The same provider_type (e.g., meta-reference)
can be instantiated multiple times (with different configs) if necessary.
""",
)
storage: StorageConfig = Field(
description="Catalog of named storage backends and references available to the stack",
)
registered_resources: RegisteredResources = Field(
default_factory=RegisteredResources,
description="Registry of resources available in the distribution",
)
logging: LoggingConfig | None = Field(default=None, description="Configuration for Llama Stack Logging")
telemetry: TelemetryConfig = Field(default_factory=TelemetryConfig, description="Configuration for telemetry")
server: ServerConfig = Field(
default_factory=ServerConfig,
description="Configuration for the HTTP(S) server",
)
external_providers_dir: Path | None = Field(
default=None,
description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
)
external_apis_dir: Path | None = Field(
default=None,
description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.",
)
vector_stores: VectorStoresConfig | None = Field(
default=None,
description="Configuration for vector stores, including default embedding model",
)
safety: SafetyConfig | None = Field(
default=None,
description="Configuration for default moderations model",
)
@field_validator("external_providers_dir")
@classmethod
def validate_external_providers_dir(cls, v):
if v is None:
return None
if isinstance(v, str):
return Path(v)
return v
@model_validator(mode="after")
def validate_server_stores(self) -> "StackRunConfig":
backend_map = self.storage.backends
stores = self.storage.stores
kv_backends = {
name
for name, cfg in backend_map.items()
if cfg.type
in {
StorageBackendType.KV_REDIS,
StorageBackendType.KV_SQLITE,
StorageBackendType.KV_POSTGRES,
StorageBackendType.KV_MONGODB,
}
}
sql_backends = {
name
for name, cfg in backend_map.items()
if cfg.type in {StorageBackendType.SQL_SQLITE, StorageBackendType.SQL_POSTGRES}
}
def _ensure_backend(reference, expected_set, store_name: str) -> None:
if reference is None:
return
backend_name = reference.backend
if backend_name not in backend_map:
raise ValueError(
f"{store_name} references unknown backend '{backend_name}'. "
f"Available backends: {sorted(backend_map)}"
)
if backend_name not in expected_set:
raise ValueError(
f"{store_name} references backend '{backend_name}' of type "
f"'{backend_map[backend_name].type.value}', but a backend of type "
f"{'kv_*' if expected_set is kv_backends else 'sql_*'} is required."
)
_ensure_backend(stores.metadata, kv_backends, "storage.stores.metadata")
_ensure_backend(stores.inference, sql_backends, "storage.stores.inference")
_ensure_backend(stores.conversations, sql_backends, "storage.stores.conversations")
_ensure_backend(stores.responses, sql_backends, "storage.stores.responses")
_ensure_backend(stores.prompts, kv_backends, "storage.stores.prompts")
return self
class BuildConfig(BaseModel):
version: int = LLAMA_STACK_BUILD_CONFIG_VERSION
distribution_spec: DistributionSpec = Field(description="The distribution spec to build including API providers. ")
image_type: str = Field(
default="venv",
description="Type of package to build (container | venv)",
)
image_name: str | None = Field(
default=None,
description="Name of the distribution to build",
)
external_providers_dir: Path | None = Field(
default=None,
description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. "
"pip_packages MUST contain the provider package name.",
)
additional_pip_packages: list[str] = Field(
default_factory=list,
description="Additional pip packages to install in the distribution. These packages will be installed in the distribution environment.",
)
external_apis_dir: Path | None = Field(
default=None,
description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.",
)
@field_validator("external_providers_dir")
@classmethod
def validate_external_providers_dir(cls, v):
if v is None:
return None
if isinstance(v, str):
return Path(v)
return v

View file

@ -0,0 +1,276 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import glob
import importlib
import os
from typing import Any
import yaml
from pydantic import BaseModel
from llama_stack.core.datatypes import BuildConfig, DistributionSpec
from llama_stack.core.external import load_external_apis
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
Api,
InlineProviderSpec,
ProviderSpec,
RemoteProviderSpec,
)
logger = get_logger(name=__name__, category="core")
INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts, Api.conversations}
def stack_apis() -> list[Api]:
return list(Api)
class AutoRoutedApiInfo(BaseModel):
routing_table_api: Api
router_api: Api
def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]:
return [
AutoRoutedApiInfo(
routing_table_api=Api.models,
router_api=Api.inference,
),
AutoRoutedApiInfo(
routing_table_api=Api.shields,
router_api=Api.safety,
),
AutoRoutedApiInfo(
routing_table_api=Api.datasets,
router_api=Api.datasetio,
),
AutoRoutedApiInfo(
routing_table_api=Api.scoring_functions,
router_api=Api.scoring,
),
AutoRoutedApiInfo(
routing_table_api=Api.benchmarks,
router_api=Api.eval,
),
AutoRoutedApiInfo(
routing_table_api=Api.tool_groups,
router_api=Api.tool_runtime,
),
AutoRoutedApiInfo(
routing_table_api=Api.vector_stores,
router_api=Api.vector_io,
),
]
def providable_apis() -> list[Api]:
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
return [api for api in Api if api not in routing_table_apis and api not in INTERNAL_APIS]
def _load_remote_provider_spec(spec_data: dict[str, Any], api: Api) -> ProviderSpec:
spec = RemoteProviderSpec(api=api, provider_type=f"remote::{spec_data['adapter_type']}", **spec_data)
return spec
def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
spec = InlineProviderSpec(api=api, provider_type=f"inline::{provider_name}", **spec_data)
return spec
def get_provider_registry(config=None) -> dict[Api, dict[str, ProviderSpec]]:
"""Get the provider registry, optionally including external providers.
This function loads both built-in providers and external providers from YAML files or from their provided modules.
External providers are loaded from a directory structure like:
providers.d/
remote/
inference/
custom_ollama.yaml
vllm.yaml
vector_io/
qdrant.yaml
safety/
llama-guard.yaml
inline/
inference/
custom_ollama.yaml
vllm.yaml
vector_io/
qdrant.yaml
safety/
llama-guard.yaml
This method is overloaded in that it can be called from a variety of places: during build, during run, during stack construction.
So when building external providers from a module, there are scenarios where the pip package required to import the module might not be available yet.
There is special handling for all of the potential cases this method can be called from.
Args:
config: Optional object containing the external providers directory path
building: Optional bool delineating whether or not this is being called from a build process
Returns:
A dictionary mapping APIs to their available providers
Raises:
FileNotFoundError: If the external providers directory doesn't exist
ValueError: If any provider spec is invalid
"""
registry: dict[Api, dict[str, ProviderSpec]] = {}
for api in providable_apis():
name = api.name.lower()
logger.debug(f"Importing module {name}")
try:
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
registry[api] = {a.provider_type: a for a in module.available_providers()}
except ImportError as e:
logger.warning(f"Failed to import module {name}: {e}")
# Refresh providable APIs with external APIs if any
external_apis = load_external_apis(config)
for api, api_spec in external_apis.items():
name = api_spec.name.lower()
logger.info(f"Importing external API {name} module {api_spec.module}")
try:
module = importlib.import_module(api_spec.module)
registry[api] = {a.provider_type: a for a in module.available_providers()}
except (ImportError, AttributeError) as e:
# Populate the registry with an empty dict to avoid breaking the provider registry
# This assume that the in-tree provider(s) are not available for this API which means
# that users will need to use external providers for this API.
registry[api] = {}
logger.error(
f"Failed to import external API {name}: {e}. Could not populate the in-tree provider(s) registry for {api.name}. \n"
"Install the API package to load any in-tree providers for this API."
)
# Check if config has external providers
if config:
if hasattr(config, "external_providers_dir") and config.external_providers_dir:
registry = get_external_providers_from_dir(registry, config)
# else lets check for modules in each provider
registry = get_external_providers_from_module(
registry=registry,
config=config,
building=(isinstance(config, BuildConfig) or isinstance(config, DistributionSpec)),
)
return registry
def get_external_providers_from_dir(
registry: dict[Api, dict[str, ProviderSpec]], config
) -> dict[Api, dict[str, ProviderSpec]]:
logger.warning(
"Specifying external providers via `external_providers_dir` is being deprecated. Please specify `module:` in the provider instead."
)
external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
if not os.path.exists(external_providers_dir):
raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
logger.info(f"Loading external providers from {external_providers_dir}")
for api in providable_apis():
api_name = api.name.lower()
# Process both remote and inline providers
for provider_type in ["remote", "inline"]:
api_dir = os.path.join(external_providers_dir, provider_type, api_name)
if not os.path.exists(api_dir):
logger.debug(f"No {provider_type} provider directory found for {api_name}")
continue
# Look for provider spec files in the API directory
for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")):
provider_name = os.path.splitext(os.path.basename(spec_path))[0]
logger.info(f"Loading {provider_type} provider spec from {spec_path}")
try:
with open(spec_path) as f:
spec_data = yaml.safe_load(f)
if provider_type == "remote":
spec = _load_remote_provider_spec(spec_data, api)
provider_type_key = f"remote::{provider_name}"
else:
spec = _load_inline_provider_spec(spec_data, api, provider_name)
provider_type_key = f"inline::{provider_name}"
logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
if provider_type_key in registry[api]:
logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
registry[api][provider_type_key] = spec
logger.info(f"Successfully loaded external provider {provider_type_key}")
except yaml.YAMLError as yaml_err:
logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
raise yaml_err
except Exception as e:
logger.error(f"Failed to load provider spec from {spec_path}: {e}")
raise e
return registry
def get_external_providers_from_module(
registry: dict[Api, dict[str, ProviderSpec]], config, building: bool
) -> dict[Api, dict[str, ProviderSpec]]:
provider_list = None
if isinstance(config, BuildConfig):
provider_list = config.distribution_spec.providers.items()
else:
provider_list = config.providers.items()
if provider_list is None:
logger.warning("Could not get list of providers from config")
return registry
for provider_api, providers in provider_list:
for provider in providers:
if not hasattr(provider, "module") or provider.module is None:
continue
# get provider using module
try:
if not building:
package_name = provider.module.split("==")[0]
module = importlib.import_module(f"{package_name}.provider")
# if config class is wrong you will get an error saying module could not be imported
spec = module.get_provider_spec()
else:
# pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run
# in the case we are building we CANNOT import this module of course because it has not been installed.
spec = ProviderSpec(
api=Api(provider_api),
provider_type=provider.provider_type,
is_external=True,
module=provider.module,
config_class="",
)
provider_type = provider.provider_type
if isinstance(spec, list):
# optionally allow people to pass inline and remote provider specs as a returned list.
# with the old method, users could pass in directories of specs using overlapping code
# we want to ensure we preserve that flexibility in this method.
logger.info(
f"Detected a list of external provider specs from {provider.module} adding all to the registry"
)
for provider_spec in spec:
if provider_spec.provider_type != provider.provider_type:
continue
logger.info(f"Adding {provider.provider_type} to registry")
registry[Api(provider_api)][provider.provider_type] = provider_spec
else:
registry[Api(provider_api)][provider_type] = spec
except ModuleNotFoundError as exc:
raise ValueError(
"get_provider_spec not found. If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available"
) from exc
except Exception as e:
logger.error(f"Failed to load provider spec from module {provider.module}: {e}")
raise e
return registry

View file

@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import yaml
from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.core.datatypes import BuildConfig, StackRunConfig
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="core")
def load_external_apis(config: StackRunConfig | BuildConfig | None) -> dict[Api, ExternalApiSpec]:
"""Load external API specifications from the configured directory.
Args:
config: StackRunConfig or BuildConfig containing the external APIs directory path
Returns:
A dictionary mapping API names to their specifications
"""
if not config or not config.external_apis_dir:
return {}
external_apis_dir = config.external_apis_dir.expanduser().resolve()
if not external_apis_dir.is_dir():
logger.error(f"External APIs directory is not a directory: {external_apis_dir}")
return {}
logger.info(f"Loading external APIs from {external_apis_dir}")
external_apis: dict[Api, ExternalApiSpec] = {}
# Look for YAML files in the external APIs directory
for yaml_path in external_apis_dir.glob("*.yaml"):
try:
with open(yaml_path) as f:
spec_data = yaml.safe_load(f)
spec = ExternalApiSpec(**spec_data)
api = Api.add(spec.name)
logger.info(f"Loaded external API spec for {spec.name} from {yaml_path}")
external_apis[api] = spec
except yaml.YAMLError as yaml_err:
logger.error(f"Failed to parse YAML file {yaml_path}: {yaml_err}")
raise
except Exception:
logger.exception(f"Failed to load external API spec from {yaml_path}")
raise
return external_apis

View file

@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Callable
IdFactory = Callable[[], str]
IdOverride = Callable[[str, IdFactory], str]
_id_override: IdOverride | None = None
def generate_object_id(kind: str, factory: IdFactory) -> str:
"""Generate an identifier for the given kind using the provided factory.
Allows tests to override ID generation deterministically by installing an
override callback via :func:`set_id_override`.
"""
override = _id_override
if override is not None:
return override(kind, factory)
return factory()
def set_id_override(override: IdOverride) -> IdOverride | None:
"""Install an override used to generate deterministic identifiers."""
global _id_override
previous = _id_override
_id_override = override
return previous
def reset_id_override(previous: IdOverride | None) -> None:
"""Restore the previous override returned by :func:`set_id_override`."""
global _id_override
_id_override = previous

View file

@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from importlib.metadata import version
from pydantic import BaseModel
from llama_stack.apis.inspect import (
HealthInfo,
Inspect,
ListRoutesResponse,
RouteInfo,
VersionInfo,
)
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.external import load_external_apis
from llama_stack.core.server.routes import get_all_api_routes
from llama_stack.providers.datatypes import HealthStatus
class DistributionInspectConfig(BaseModel):
run_config: StackRunConfig
async def get_provider_impl(config, deps):
impl = DistributionInspectImpl(config, deps)
await impl.initialize()
return impl
class DistributionInspectImpl(Inspect):
def __init__(self, config: DistributionInspectConfig, deps):
self.config = config
self.deps = deps
async def initialize(self) -> None:
pass
async def list_routes(self) -> ListRoutesResponse:
run_config: StackRunConfig = self.config.run_config
ret = []
external_apis = load_external_apis(run_config)
all_endpoints = get_all_api_routes(external_apis)
for api, endpoints in all_endpoints.items():
# Always include provider and inspect APIs, filter others based on run config
if api.value in ["providers", "inspect"]:
ret.extend(
[
RouteInfo(
route=e.path,
method=next(iter([m for m in e.methods if m != "HEAD"])),
provider_types=[], # These APIs don't have "real" providers - they're internal to the stack
)
for e, _ in endpoints
if e.methods is not None
]
)
else:
providers = run_config.providers.get(api.value, [])
if providers: # Only process if there are providers for this API
ret.extend(
[
RouteInfo(
route=e.path,
method=next(iter([m for m in e.methods if m != "HEAD"])),
provider_types=[p.provider_type for p in providers],
)
for e, _ in endpoints
if e.methods is not None
]
)
return ListRoutesResponse(data=ret)
async def health(self) -> HealthInfo:
return HealthInfo(status=HealthStatus.OK)
async def version(self) -> VersionInfo:
return VersionInfo(version=version("llama-stack"))
async def shutdown(self) -> None:
pass

View file

@ -0,0 +1,540 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import inspect
import json
import logging # allow-direct-logging
import os
import sys
from enum import Enum
from io import BytesIO
from pathlib import Path
from typing import Any, TypeVar, Union, get_args, get_origin
import httpx
import yaml
from fastapi import Response as FastAPIResponse
from llama_stack_client import (
NOT_GIVEN,
APIResponse,
AsyncAPIResponse,
AsyncLlamaStackClient,
AsyncStream,
LlamaStackClient,
)
from pydantic import BaseModel, TypeAdapter
from rich.console import Console
from termcolor import cprint
from llama_stack.core.build import print_pip_install_help
from llama_stack.core.configure import parse_and_maybe_upgrade_config
from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
from llama_stack.core.request_headers import (
PROVIDER_DATA_VAR,
request_provider_data_context,
)
from llama_stack.core.resolver import ProviderRegistry
from llama_stack.core.server.routes import RouteImpls, find_matching_route, initialize_route_impls
from llama_stack.core.stack import (
Stack,
get_stack_run_config_from_distro,
replace_env_vars,
)
from llama_stack.core.telemetry import Telemetry
from llama_stack.core.telemetry.tracing import CURRENT_TRACE_CONTEXT, end_trace, setup_logger, start_trace
from llama_stack.core.utils.config import redact_sensitive_fields
from llama_stack.core.utils.context import preserve_contexts_async_generator
from llama_stack.core.utils.exec import in_notebook
from llama_stack.log import get_logger, setup_logging
from llama_stack.strong_typing.inspection import is_unwrapped_body_param
logger = get_logger(name=__name__, category="core")
T = TypeVar("T")
def convert_pydantic_to_json_value(value: Any) -> Any:
if isinstance(value, Enum):
return value.value
elif isinstance(value, list):
return [convert_pydantic_to_json_value(item) for item in value]
elif isinstance(value, dict):
return {k: convert_pydantic_to_json_value(v) for k, v in value.items()}
elif isinstance(value, BaseModel):
return json.loads(value.model_dump_json())
else:
return value
def convert_to_pydantic(annotation: Any, value: Any) -> Any:
if isinstance(annotation, type) and annotation in {str, int, float, bool}:
return value
origin = get_origin(annotation)
if origin is list:
item_type = get_args(annotation)[0]
try:
return [convert_to_pydantic(item_type, item) for item in value]
except Exception:
logger.error(f"Error converting list {value} into {item_type}")
return value
elif origin is dict:
key_type, val_type = get_args(annotation)
try:
return {k: convert_to_pydantic(val_type, v) for k, v in value.items()}
except Exception:
logger.error(f"Error converting dict {value} into {val_type}")
return value
try:
# Handle Pydantic models and discriminated unions
return TypeAdapter(annotation).validate_python(value)
except Exception as e:
# TODO: this is workardound for having Union[str, AgentToolGroup] in API schema.
# We should get rid of any non-discriminated unions in the API schema.
if origin is Union:
for union_type in get_args(annotation):
try:
return convert_to_pydantic(union_type, value)
except Exception:
continue
logger.warning(
f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
)
raise ValueError(f"Failed to convert parameter {value} into {annotation}: {e}") from e
class LibraryClientUploadFile:
"""LibraryClient UploadFile object that mimics FastAPI's UploadFile interface."""
def __init__(self, filename: str, content: bytes):
self.filename = filename
self.content = content
self.content_type = "application/octet-stream"
async def read(self) -> bytes:
return self.content
class LibraryClientHttpxResponse:
"""LibraryClient httpx Response object for FastAPI Response conversion."""
def __init__(self, response):
self.content = response.body if isinstance(response.body, bytes) else response.body.encode()
self.status_code = response.status_code
self.headers = response.headers
class LlamaStackAsLibraryClient(LlamaStackClient):
def __init__(
self,
config_path_or_distro_name: str,
skip_logger_removal: bool = False,
custom_provider_registry: ProviderRegistry | None = None,
provider_data: dict[str, Any] | None = None,
):
super().__init__()
self.async_client = AsyncLlamaStackAsLibraryClient(
config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal
)
self.provider_data = provider_data
self.loop = asyncio.new_event_loop()
# use a new event loop to avoid interfering with the main event loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(self.async_client.initialize())
finally:
asyncio.set_event_loop(None)
def initialize(self):
"""
Deprecated method for backward compatibility.
"""
pass
def request(self, *args, **kwargs):
loop = self.loop
asyncio.set_event_loop(loop)
if kwargs.get("stream"):
def sync_generator():
try:
async_stream = loop.run_until_complete(self.async_client.request(*args, **kwargs))
while True:
chunk = loop.run_until_complete(async_stream.__anext__())
yield chunk
except StopAsyncIteration:
pass
finally:
pending = asyncio.all_tasks(loop)
if pending:
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
return sync_generator()
else:
try:
result = loop.run_until_complete(self.async_client.request(*args, **kwargs))
finally:
pending = asyncio.all_tasks(loop)
if pending:
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
return result
class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
def __init__(
self,
config_path_or_distro_name: str,
custom_provider_registry: ProviderRegistry | None = None,
provider_data: dict[str, Any] | None = None,
skip_logger_removal: bool = False,
):
super().__init__()
# Initialize logging from environment variables first
setup_logging()
# when using the library client, we should not log to console since many
# of our logs are intended for server-side usage
if sinks_from_env := os.environ.get("TELEMETRY_SINKS", None):
current_sinks = sinks_from_env.strip().lower().split(",")
os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")
if in_notebook():
import nest_asyncio
nest_asyncio.apply()
if not skip_logger_removal:
self._remove_root_logger_handlers()
if config_path_or_distro_name.endswith(".yaml"):
config_path = Path(config_path_or_distro_name)
if not config_path.exists():
raise ValueError(f"Config file {config_path} does not exist")
config_dict = replace_env_vars(yaml.safe_load(config_path.read_text()))
config = parse_and_maybe_upgrade_config(config_dict)
else:
# distribution
config = get_stack_run_config_from_distro(config_path_or_distro_name)
self.config_path_or_distro_name = config_path_or_distro_name
self.config = config
self.custom_provider_registry = custom_provider_registry
self.provider_data = provider_data
self.route_impls: RouteImpls | None = None # Initialize to None to prevent AttributeError
def _remove_root_logger_handlers(self):
"""
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
"""
root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
async def initialize(self) -> bool:
"""
Initialize the async client.
Returns:
bool: True if initialization was successful
"""
try:
self.route_impls = None
stack = Stack(self.config, self.custom_provider_registry)
await stack.initialize()
self.impls = stack.impls
except ModuleNotFoundError as _e:
cprint(_e.msg, color="red", file=sys.stderr)
cprint(
"Using llama-stack as a library requires installing dependencies depending on the distribution (providers) you choose.\n",
color="yellow",
file=sys.stderr,
)
if self.config_path_or_distro_name.endswith(".yaml"):
providers: dict[str, list[BuildProvider]] = {}
for api, run_providers in self.config.providers.items():
for provider in run_providers:
providers.setdefault(api, []).append(
BuildProvider(provider_type=provider.provider_type, module=provider.module)
)
providers = dict(providers)
build_config = BuildConfig(
distribution_spec=DistributionSpec(
providers=providers,
),
external_providers_dir=self.config.external_providers_dir,
)
print_pip_install_help(build_config)
else:
prefix = "!" if in_notebook() else ""
cprint(
f"Please run:\n\n{prefix}llama stack list-deps {self.config_path_or_distro_name} | xargs -L1 uv pip install\n\n",
"yellow",
file=sys.stderr,
)
cprint(
"Please check your internet connection and try again.",
"red",
file=sys.stderr,
)
raise _e
assert self.impls is not None
if self.config.telemetry.enabled:
setup_logger(Telemetry())
if not os.environ.get("PYTEST_CURRENT_TEST"):
console = Console()
console.print(f"Using config [blue]{self.config_path_or_distro_name}[/blue]:")
safe_config = redact_sensitive_fields(self.config.model_dump())
console.print(yaml.dump(safe_config, indent=2))
self.route_impls = initialize_route_impls(self.impls)
return True
async def request(
self,
cast_to: Any,
options: Any,
*,
stream=False,
stream_cls=None,
):
if self.route_impls is None:
raise ValueError("Client not initialized. Please call initialize() first.")
# Create headers with provider data if available
headers = options.headers or {}
if self.provider_data:
keys = ["X-LlamaStack-Provider-Data", "x-llamastack-provider-data"]
if all(key not in headers for key in keys):
headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data)
# Use context manager for provider data
with request_provider_data_context(headers):
if stream:
response = await self._call_streaming(
cast_to=cast_to,
options=options,
stream_cls=stream_cls,
)
else:
response = await self._call_non_streaming(
cast_to=cast_to,
options=options,
)
return response
def _handle_file_uploads(self, options: Any, body: dict) -> tuple[dict, list[str]]:
"""Handle file uploads from OpenAI client and add them to the request body."""
if not (hasattr(options, "files") and options.files):
return body, []
if not isinstance(options.files, list):
return body, []
field_names = []
for file_tuple in options.files:
if not (isinstance(file_tuple, tuple) and len(file_tuple) >= 2):
continue
field_name = file_tuple[0]
file_object = file_tuple[1]
if isinstance(file_object, BytesIO):
file_object.seek(0)
file_content = file_object.read()
filename = getattr(file_object, "name", "uploaded_file")
field_names.append(field_name)
body[field_name] = LibraryClientUploadFile(filename, file_content)
return body, field_names
async def _call_non_streaming(
self,
*,
cast_to: Any,
options: Any,
):
assert self.route_impls is not None # Should be guaranteed by request() method, assertion for mypy
path = options.url
body = options.params or {}
body |= options.json_data or {}
# Merge extra_json parameters (extra_body from SDK is converted to extra_json)
if hasattr(options, "extra_json") and options.extra_json:
body |= options.extra_json
matched_func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
body |= path_params
body, field_names = self._handle_file_uploads(options, body)
body = self._convert_body(matched_func, body, exclude_params=set(field_names))
trace_path = webmethod.descriptive_name or route_path
await start_trace(trace_path, {"__location__": "library_client"})
try:
result = await matched_func(**body)
finally:
await end_trace()
# Handle FastAPI Response objects (e.g., from file content retrieval)
if isinstance(result, FastAPIResponse):
return LibraryClientHttpxResponse(result)
json_content = json.dumps(convert_pydantic_to_json_value(result))
filtered_body = {k: v for k, v in body.items() if not isinstance(v, LibraryClientUploadFile)}
status_code = httpx.codes.OK
if options.method.upper() == "DELETE" and result is None:
status_code = httpx.codes.NO_CONTENT
if status_code == httpx.codes.NO_CONTENT:
json_content = ""
mock_response = httpx.Response(
status_code=status_code,
content=json_content.encode("utf-8"),
headers={
"Content-Type": "application/json",
},
request=httpx.Request(
method=options.method,
url=options.url,
params=options.params,
headers=options.headers or {},
json=convert_pydantic_to_json_value(filtered_body),
),
)
response = APIResponse(
raw=mock_response,
client=self,
cast_to=cast_to,
options=options,
stream=False,
stream_cls=None,
)
return response.parse()
async def _call_streaming(
self,
*,
cast_to: Any,
options: Any,
stream_cls: Any,
):
assert self.route_impls is not None # Should be guaranteed by request() method, assertion for mypy
path = options.url
body = options.params or {}
body |= options.json_data or {}
func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
body |= path_params
# Prepare body for the function call (handles both Pydantic and traditional params)
body = self._convert_body(func, body)
trace_path = webmethod.descriptive_name or route_path
await start_trace(trace_path, {"__location__": "library_client"})
async def gen():
try:
async for chunk in await func(**body):
data = json.dumps(convert_pydantic_to_json_value(chunk))
sse_event = f"data: {data}\n\n"
yield sse_event.encode("utf-8")
finally:
await end_trace()
wrapped_gen = preserve_contexts_async_generator(gen(), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR])
mock_response = httpx.Response(
status_code=httpx.codes.OK,
content=wrapped_gen,
headers={
"Content-Type": "application/json",
},
request=httpx.Request(
method=options.method,
url=options.url,
params=options.params,
headers=options.headers or {},
json=convert_pydantic_to_json_value(body),
),
)
# we use asynchronous impl always internally and channel all requests to AsyncLlamaStackClient
# however, the top-level caller may be a SyncAPIClient -- so its stream_cls might be a Stream (SyncStream)
# so we need to convert it to AsyncStream
# mypy can't track runtime variables inside the [...] of a generic, so ignore that check
args = get_args(stream_cls)
stream_cls = AsyncStream[args[0]] # type: ignore[valid-type]
response = AsyncAPIResponse(
raw=mock_response,
client=self,
cast_to=cast_to,
options=options,
stream=True,
stream_cls=stream_cls,
)
return await response.parse()
def _convert_body(self, func: Any, body: dict | None = None, exclude_params: set[str] | None = None) -> dict:
body = body or {}
exclude_params = exclude_params or set()
sig = inspect.signature(func)
params_list = [p for p in sig.parameters.values() if p.name != "self"]
# Flatten if there's a single unwrapped body parameter (BaseModel or Annotated[BaseModel, Body(embed=False)])
if len(params_list) == 1:
param = params_list[0]
param_type = param.annotation
if is_unwrapped_body_param(param_type):
base_type = get_args(param_type)[0]
return {param.name: base_type(**body)}
# Strip NOT_GIVENs to use the defaults in signature
body = {k: v for k, v in body.items() if v is not NOT_GIVEN}
# Check if there's an unwrapped body parameter among multiple parameters
# (e.g., path param + body param like: vector_store_id: str, params: Annotated[Model, Body(...)])
unwrapped_body_param = None
for param in params_list:
if is_unwrapped_body_param(param.annotation):
unwrapped_body_param = param
break
# Convert parameters to Pydantic models where needed
converted_body = {}
for param_name, param in sig.parameters.items():
if param_name in body:
value = body.get(param_name)
if param_name in exclude_params:
converted_body[param_name] = value
else:
converted_body[param_name] = convert_to_pydantic(param.annotation, value)
# handle unwrapped body parameter after processing all named parameters
if unwrapped_body_param:
base_type = get_args(unwrapped_body_param.annotation)[0]
# extract only keys not already used by other params
remaining_keys = {k: v for k, v in body.items() if k not in converted_body}
converted_body[unwrapped_body_param.name] = base_type(**remaining_keys)
return converted_body

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,232 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Any
from pydantic import BaseModel
from llama_stack.apis.prompts import ListPromptsResponse, Prompt, Prompts
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
class PromptServiceConfig(BaseModel):
"""Configuration for the built-in prompt service.
:param run_config: Stack run configuration containing distribution info
"""
run_config: StackRunConfig
async def get_provider_impl(config: PromptServiceConfig, deps: dict[Any, Any]):
"""Get the prompt service implementation."""
impl = PromptServiceImpl(config, deps)
await impl.initialize()
return impl
class PromptServiceImpl(Prompts):
"""Built-in prompt service implementation using KVStore."""
def __init__(self, config: PromptServiceConfig, deps: dict[Any, Any]):
self.config = config
self.deps = deps
self.kvstore: KVStore
async def initialize(self) -> None:
# Use prompts store reference from run config
prompts_ref = self.config.run_config.storage.stores.prompts
if not prompts_ref:
raise ValueError("storage.stores.prompts must be configured in run config")
self.kvstore = await kvstore_impl(prompts_ref)
def _get_default_key(self, prompt_id: str) -> str:
"""Get the KVStore key that stores the default version number."""
return f"prompts:v1:{prompt_id}:default"
async def _get_prompt_key(self, prompt_id: str, version: int | None = None) -> str:
"""Get the KVStore key for prompt data, returning default version if applicable."""
if version:
return self._get_version_key(prompt_id, str(version))
default_key = self._get_default_key(prompt_id)
resolved_version = await self.kvstore.get(default_key)
if resolved_version is None:
raise ValueError(f"Prompt {prompt_id}:default not found")
return self._get_version_key(prompt_id, resolved_version)
def _get_version_key(self, prompt_id: str, version: str) -> str:
"""Get the KVStore key for a specific prompt version."""
return f"prompts:v1:{prompt_id}:{version}"
def _get_list_key_prefix(self) -> str:
"""Get the key prefix for listing prompts."""
return "prompts:v1:"
def _serialize_prompt(self, prompt: Prompt) -> str:
"""Serialize a prompt to JSON string for storage."""
return json.dumps(
{
"prompt_id": prompt.prompt_id,
"prompt": prompt.prompt,
"version": prompt.version,
"variables": prompt.variables or [],
"is_default": prompt.is_default,
}
)
def _deserialize_prompt(self, data: str) -> Prompt:
"""Deserialize a prompt from JSON string."""
obj = json.loads(data)
return Prompt(
prompt_id=obj["prompt_id"],
prompt=obj["prompt"],
version=obj["version"],
variables=obj.get("variables", []),
is_default=obj.get("is_default", False),
)
async def list_prompts(self) -> ListPromptsResponse:
"""List all prompts (default versions only)."""
prefix = self._get_list_key_prefix()
keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
prompts = []
for key in keys:
if key.endswith(":default"):
try:
default_version = await self.kvstore.get(key)
if default_version:
prompt_id = key.replace(prefix, "").replace(":default", "")
version_key = self._get_version_key(prompt_id, default_version)
data = await self.kvstore.get(version_key)
if data:
prompt = self._deserialize_prompt(data)
prompts.append(prompt)
except (json.JSONDecodeError, KeyError):
continue
prompts.sort(key=lambda p: p.prompt_id or "", reverse=True)
return ListPromptsResponse(data=prompts)
async def get_prompt(self, prompt_id: str, version: int | None = None) -> Prompt:
"""Get a prompt by its identifier and optional version."""
key = await self._get_prompt_key(prompt_id, version)
data = await self.kvstore.get(key)
if data is None:
raise ValueError(f"Prompt {prompt_id}:{version if version else 'default'} not found")
return self._deserialize_prompt(data)
async def create_prompt(
self,
prompt: str,
variables: list[str] | None = None,
) -> Prompt:
"""Create a new prompt."""
if variables is None:
variables = []
prompt_obj = Prompt(
prompt_id=Prompt.generate_prompt_id(),
prompt=prompt,
version=1,
variables=variables,
)
version_key = self._get_version_key(prompt_obj.prompt_id, str(prompt_obj.version))
data = self._serialize_prompt(prompt_obj)
await self.kvstore.set(version_key, data)
default_key = self._get_default_key(prompt_obj.prompt_id)
await self.kvstore.set(default_key, str(prompt_obj.version))
return prompt_obj
async def update_prompt(
self,
prompt_id: str,
prompt: str,
version: int,
variables: list[str] | None = None,
set_as_default: bool = True,
) -> Prompt:
"""Update an existing prompt (increments version)."""
if version < 1:
raise ValueError("Version must be >= 1")
if variables is None:
variables = []
prompt_versions = await self.list_prompt_versions(prompt_id)
latest_prompt = max(prompt_versions.data, key=lambda x: int(x.version))
if version and latest_prompt.version != version:
raise ValueError(
f"'{version}' is not the latest prompt version for prompt_id='{prompt_id}'. Use the latest version '{latest_prompt.version}' in request."
)
current_version = latest_prompt.version if version is None else version
new_version = current_version + 1
updated_prompt = Prompt(prompt_id=prompt_id, prompt=prompt, version=new_version, variables=variables)
version_key = self._get_version_key(prompt_id, str(new_version))
data = self._serialize_prompt(updated_prompt)
await self.kvstore.set(version_key, data)
if set_as_default:
await self.set_default_version(prompt_id, new_version)
return updated_prompt
async def delete_prompt(self, prompt_id: str) -> None:
"""Delete a prompt and all its versions."""
await self.get_prompt(prompt_id)
prefix = f"prompts:v1:{prompt_id}:"
keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
for key in keys:
await self.kvstore.delete(key)
async def list_prompt_versions(self, prompt_id: str) -> ListPromptsResponse:
"""List all versions of a specific prompt."""
prefix = f"prompts:v1:{prompt_id}:"
keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
default_version = None
prompts = []
for key in keys:
data = await self.kvstore.get(key)
if key.endswith(":default"):
default_version = data
else:
if data:
prompt_obj = self._deserialize_prompt(data)
prompts.append(prompt_obj)
if not prompts:
raise ValueError(f"Prompt {prompt_id} not found")
for prompt in prompts:
prompt.is_default = str(prompt.version) == default_version
prompts.sort(key=lambda x: x.version)
return ListPromptsResponse(data=prompts)
async def set_default_version(self, prompt_id: str, version: int) -> Prompt:
"""Set which version of a prompt should be the default, If not set. the default is the latest."""
version_key = self._get_version_key(prompt_id, str(version))
data = await self.kvstore.get(version_key)
if data is None:
raise ValueError(f"Prompt {prompt_id} version {version} not found")
default_key = self._get_default_key(prompt_id)
await self.kvstore.set(default_key, str(version))
return self._deserialize_prompt(data)

View file

@ -0,0 +1,137 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import Any
from pydantic import BaseModel
from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus
from .datatypes import StackRunConfig
from .utils.config import redact_sensitive_fields
logger = get_logger(name=__name__, category="core")
class ProviderImplConfig(BaseModel):
run_config: StackRunConfig
async def get_provider_impl(config, deps):
impl = ProviderImpl(config, deps)
await impl.initialize()
return impl
class ProviderImpl(Providers):
def __init__(self, config, deps):
self.config = config
self.deps = deps
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
logger.debug("ProviderImpl.shutdown")
pass
async def list_providers(self) -> ListProvidersResponse:
run_config = self.config.run_config
safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))
providers_health = await self.get_providers_health()
ret = []
for api, providers in safe_config.providers.items():
for p in providers:
# Skip providers that are not enabled
if p.provider_id is None:
continue
ret.append(
ProviderInfo(
api=api,
provider_id=p.provider_id,
provider_type=p.provider_type,
config=p.config,
health=providers_health.get(api, {}).get(
p.provider_id,
HealthResponse(
status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"
),
),
)
)
return ListProvidersResponse(data=ret)
async def inspect_provider(self, provider_id: str) -> ProviderInfo:
all_providers = await self.list_providers()
for p in all_providers.data:
if p.provider_id == provider_id:
return p
raise ValueError(f"Provider {provider_id} not found")
async def get_providers_health(self) -> dict[str, dict[str, HealthResponse]]:
"""Get health status for all providers.
Returns:
Dict[str, Dict[str, HealthResponse]]: A dictionary mapping API names to provider health statuses.
Each API maps to a dictionary of provider IDs to their health responses.
"""
providers_health: dict[str, dict[str, HealthResponse]] = {}
# The timeout has to be long enough to allow all the providers to be checked, especially in
# the case of the inference router health check since it checks all registered inference
# providers.
# The timeout must not be equal to the one set by health method for a given implementation,
# otherwise we will miss some providers.
timeout = 3.0
async def check_provider_health(impl: Any) -> tuple[str, HealthResponse] | None:
# Skip special implementations (inspect/providers) that don't have provider specs
if not hasattr(impl, "__provider_spec__"):
return None
api_name = impl.__provider_spec__.api.name
if not hasattr(impl, "health"):
return (
api_name,
HealthResponse(
status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"
),
)
try:
health = await asyncio.wait_for(impl.health(), timeout=timeout)
return api_name, health
except TimeoutError:
return (
api_name,
HealthResponse(
status=HealthStatus.ERROR, message=f"Health check timed out after {timeout} seconds"
),
)
except Exception as e:
return (
api_name,
HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"),
)
# Create tasks for all providers
tasks = [check_provider_health(impl) for impl in self.deps.values()]
# Wait for all health checks to complete
results = await asyncio.gather(*tasks)
# Organize results by API and provider ID
for result in results:
if result is None: # Skip special implementations
continue
api_name, health_response = result
providers_health[api_name] = health_response
return providers_health

View file

@ -0,0 +1,115 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import contextvars
import json
from contextlib import AbstractContextManager
from typing import Any
from llama_stack.core.datatypes import User
from llama_stack.log import get_logger
from .utils.dynamic import instantiate_class_type
log = get_logger(name=__name__, category="core")
# Context variable for request provider data and auth attributes
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
class RequestProviderDataContext(AbstractContextManager):
"""Context manager for request provider data"""
def __init__(self, provider_data: dict[str, Any] | None = None, user: User | None = None):
self.provider_data = provider_data or {}
if user:
self.provider_data["__authenticated_user"] = user
self.token = None
def __enter__(self):
# Save the current value and set the new one
self.token = PROVIDER_DATA_VAR.set(self.provider_data)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Restore the previous value
if self.token is not None:
PROVIDER_DATA_VAR.reset(self.token)
class NeedsRequestProviderData:
def get_request_provider_data(self) -> Any:
spec = self.__provider_spec__
if not spec:
raise ValueError(f"Provider spec not set on {self.__class__}")
provider_type = spec.provider_type
validator_class = spec.provider_data_validator
if not validator_class:
raise ValueError(f"Provider {provider_type} does not have a validator")
val = PROVIDER_DATA_VAR.get()
if not val:
return None
validator = instantiate_class_type(validator_class)
try:
provider_data = validator(**val)
return provider_data
except Exception as e:
log.error(f"Error parsing provider data: {e}")
return None
def parse_request_provider_data(headers: dict[str, str]) -> dict[str, Any] | None:
"""Parse provider data from request headers"""
keys = [
"X-LlamaStack-Provider-Data",
"x-llamastack-provider-data",
]
val = None
for key in keys:
val = headers.get(key, None)
if val:
break
if not val:
return None
try:
return json.loads(val)
except json.JSONDecodeError:
log.error("Provider data not encoded as a JSON object!")
return None
def request_provider_data_context(
headers: dict[str, str], auth_attributes: dict[str, list[str]] | None = None
) -> AbstractContextManager:
"""Context manager that sets request provider data from headers and auth attributes for the duration of the context"""
provider_data = parse_request_provider_data(headers)
return RequestProviderDataContext(provider_data, auth_attributes)
def get_authenticated_user() -> User | None:
"""Helper to retrieve auth attributes from the provider data context"""
provider_data = PROVIDER_DATA_VAR.get()
if not provider_data:
return None
return provider_data.get("__authenticated_user")
def user_from_scope(scope: dict) -> User | None:
"""Create a User object from ASGI scope data (set by authentication middleware)"""
user_attributes = scope.get("user_attributes", {})
principal = scope.get("principal", "")
# auth not enabled
if not principal and not user_attributes:
return None
return User(principal=principal, attributes=user_attributes)

View file

@ -0,0 +1,482 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import importlib
import importlib.metadata
import inspect
from typing import Any
from llama_stack.apis.agents import Agents
from llama_stack.apis.batches import Batches
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.datatypes import ExternalApiSpec
from llama_stack.apis.eval import Eval
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InferenceProvider
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.prompts import Prompts
from llama_stack.apis.providers import Providers as ProvidersAPI
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
from llama_stack.core.client import get_client_impl
from llama_stack.core.datatypes import (
AccessRule,
AutoRoutedProviderSpec,
Provider,
RoutingTableProviderSpec,
StackRunConfig,
)
from llama_stack.core.distribution import builtin_automatically_routed_apis
from llama_stack.core.external import load_external_apis
from llama_stack.core.store import DistributionRegistry
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
Api,
BenchmarksProtocolPrivate,
DatasetsProtocolPrivate,
ModelsProtocolPrivate,
ProviderSpec,
RemoteProviderConfig,
RemoteProviderSpec,
ScoringFunctionsProtocolPrivate,
ShieldsProtocolPrivate,
ToolGroupsProtocolPrivate,
)
logger = get_logger(name=__name__, category="core")
class InvalidProviderError(Exception):
pass
def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) -> dict[Api, Any]:
"""Get a mapping of API types to their protocol classes.
Args:
external_apis: Optional dictionary of external API specifications
Returns:
Dictionary mapping API types to their protocol classes
"""
protocols = {
Api.providers: ProvidersAPI,
Api.agents: Agents,
Api.inference: Inference,
Api.inspect: Inspect,
Api.batches: Batches,
Api.vector_io: VectorIO,
Api.vector_stores: VectorStore,
Api.models: Models,
Api.safety: Safety,
Api.shields: Shields,
Api.datasetio: DatasetIO,
Api.datasets: Datasets,
Api.scoring: Scoring,
Api.scoring_functions: ScoringFunctions,
Api.eval: Eval,
Api.benchmarks: Benchmarks,
Api.post_training: PostTraining,
Api.tool_groups: ToolGroups,
Api.tool_runtime: ToolRuntime,
Api.files: Files,
Api.prompts: Prompts,
Api.conversations: Conversations,
}
if external_apis:
for api, api_spec in external_apis.items():
try:
module = importlib.import_module(api_spec.module)
api_class = getattr(module, api_spec.protocol)
protocols[api] = api_class
except (ImportError, AttributeError):
logger.exception(f"Failed to load external API {api_spec.name}")
return protocols
def api_protocol_map_for_compliance_check(config: Any) -> dict[Api, Any]:
external_apis = load_external_apis(config)
return {
**api_protocol_map(external_apis),
Api.inference: InferenceProvider,
}
def additional_protocols_map() -> dict[Api, Any]:
return {
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups),
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
Api.scoring: (
ScoringFunctionsProtocolPrivate,
ScoringFunctions,
Api.scoring_functions,
),
Api.eval: (BenchmarksProtocolPrivate, Benchmarks, Api.benchmarks),
}
# TODO: make all this naming far less atrocious. Provider. ProviderSpec. ProviderWithSpec. WTF!
class ProviderWithSpec(Provider):
spec: ProviderSpec
ProviderRegistry = dict[Api, dict[str, ProviderSpec]]
async def resolve_impls(
run_config: StackRunConfig,
provider_registry: ProviderRegistry,
dist_registry: DistributionRegistry,
policy: list[AccessRule],
internal_impls: dict[Api, Any] | None = None,
) -> dict[Api, Any]:
"""
Resolves provider implementations by:
1. Validating and organizing providers.
2. Sorting them in dependency order.
3. Instantiating them with required dependencies.
"""
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
router_apis = {x.router_api for x in builtin_automatically_routed_apis()}
providers_with_specs = validate_and_prepare_providers(
run_config, provider_registry, routing_table_apis, router_apis
)
apis_to_serve = run_config.apis or set(
list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
)
providers_with_specs.update(specs_for_autorouted_apis(apis_to_serve))
sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config, policy, internal_impls)
def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
"""Generates specifications for automatically routed APIs."""
specs = {}
for info in builtin_automatically_routed_apis():
if info.router_api.value not in apis_to_serve:
continue
specs[info.routing_table_api.value] = {
"__builtin__": ProviderWithSpec(
provider_id="__routing_table__",
provider_type="__routing_table__",
config={},
spec=RoutingTableProviderSpec(
api=info.routing_table_api,
router_api=info.router_api,
module="llama_stack.core.routers",
api_dependencies=[],
deps__=[f"inner-{info.router_api.value}"],
),
)
}
specs[info.router_api.value] = {
"__builtin__": ProviderWithSpec(
provider_id="__autorouted__",
provider_type="__autorouted__",
config={},
spec=AutoRoutedProviderSpec(
api=info.router_api,
module="llama_stack.core.routers",
routing_table_api=info.routing_table_api,
api_dependencies=[info.routing_table_api],
deps__=([info.routing_table_api.value]),
),
)
}
return specs
def validate_and_prepare_providers(
run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: set[Api], router_apis: set[Api]
) -> dict[str, dict[str, ProviderWithSpec]]:
"""Validates providers, handles deprecations, and organizes them into a spec dictionary."""
providers_with_specs: dict[str, dict[str, ProviderWithSpec]] = {}
for api_str, providers in run_config.providers.items():
api = Api(api_str)
if api in routing_table_apis:
raise ValueError(f"Provider for `{api_str}` is automatically provided and cannot be overridden")
specs = {}
for provider in providers:
if not provider.provider_id or provider.provider_id == "__disabled__":
logger.debug(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
continue
validate_provider(provider, api, provider_registry)
p = provider_registry[api][provider.provider_type]
p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
spec = ProviderWithSpec(spec=p, **provider.model_dump())
specs[provider.provider_id] = spec
key = api_str if api not in router_apis else f"inner-{api_str}"
providers_with_specs[key] = specs
return providers_with_specs
def validate_provider(provider: Provider, api: Api, provider_registry: ProviderRegistry):
"""Validates if the provider is allowed and handles deprecations."""
if provider.provider_type not in provider_registry[api]:
raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")
p = provider_registry[api][provider.provider_type]
if p.deprecation_error:
logger.error(p.deprecation_error)
raise InvalidProviderError(p.deprecation_error)
elif p.deprecation_warning:
logger.warning(
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
)
def sort_providers_by_deps(
providers_with_specs: dict[str, dict[str, ProviderWithSpec]], run_config: StackRunConfig
) -> list[tuple[str, ProviderWithSpec]]:
"""Sorts providers based on their dependencies."""
sorted_providers: list[tuple[str, ProviderWithSpec]] = topological_sort(
{k: list(v.values()) for k, v in providers_with_specs.items()}
)
logger.debug(f"Resolved {len(sorted_providers)} providers")
for api_str, provider in sorted_providers:
logger.debug(f" {api_str} => {provider.provider_id}")
return sorted_providers
async def instantiate_providers(
sorted_providers: list[tuple[str, ProviderWithSpec]],
router_apis: set[Api],
dist_registry: DistributionRegistry,
run_config: StackRunConfig,
policy: list[AccessRule],
internal_impls: dict[Api, Any] | None = None,
) -> dict[Api, Any]:
"""Instantiates providers asynchronously while managing dependencies."""
impls: dict[Api, Any] = internal_impls.copy() if internal_impls else {}
inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
for api_str, provider in sorted_providers:
# Skip providers that are not enabled
if provider.provider_id is None:
continue
try:
deps = {a: impls[a] for a in provider.spec.api_dependencies}
except KeyError as e:
missing_api = e.args[0]
raise RuntimeError(
f"Failed to resolve '{provider.spec.api.value}' provider '{provider.provider_id}' of type '{provider.spec.provider_type}': "
f"required dependency '{missing_api.value}' is not available. "
f"Please add a '{missing_api.value}' provider to your configuration or check if the provider is properly configured."
) from e
for a in provider.spec.optional_api_dependencies:
if a in impls:
deps[a] = impls[a]
inner_impls = {}
if isinstance(provider.spec, RoutingTableProviderSpec):
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config, policy)
if api_str.startswith("inner-"):
inner_impls_by_provider_id[api_str][provider.provider_id] = impl
else:
api = Api(api_str)
impls[api] = impl
return impls
def topological_sort(
providers_with_specs: dict[str, list[ProviderWithSpec]],
) -> list[tuple[str, ProviderWithSpec]]:
def dfs(kv, visited: set[str], stack: list[str]):
api_str, providers = kv
visited.add(api_str)
deps = []
for provider in providers:
for dep in provider.spec.deps__:
deps.append(dep)
for dep in deps:
if dep not in visited and dep in providers_with_specs:
dfs((dep, providers_with_specs[dep]), visited, stack)
stack.append(api_str)
visited: set[str] = set()
stack: list[str] = []
for api_str, providers in providers_with_specs.items():
if api_str not in visited:
dfs((api_str, providers), visited, stack)
flattened = []
for api_str in stack:
for provider in providers_with_specs[api_str]:
flattened.append((api_str, provider))
return flattened
# returns a class implementing the protocol corresponding to the Api
async def instantiate_provider(
provider: ProviderWithSpec,
deps: dict[Api, Any],
inner_impls: dict[str, Any],
dist_registry: DistributionRegistry,
run_config: StackRunConfig,
policy: list[AccessRule],
):
provider_spec = provider.spec
if not hasattr(provider_spec, "module") or provider_spec.module is None:
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
logger.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}")
module = importlib.import_module(provider_spec.module)
args = []
if isinstance(provider_spec, RemoteProviderSpec):
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider.config)
method = "get_adapter_impl"
args = [config, deps]
elif isinstance(provider_spec, AutoRoutedProviderSpec):
method = "get_auto_router_impl"
config = None
args = [provider_spec.api, deps[provider_spec.routing_table_api], deps, run_config, policy]
elif isinstance(provider_spec, RoutingTableProviderSpec):
method = "get_routing_table_impl"
config = None
args = [provider_spec.api, inner_impls, deps, dist_registry, policy]
else:
method = "get_provider_impl"
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider.config)
args = [config, deps]
if "policy" in inspect.signature(getattr(module, method)).parameters:
args.append(policy)
if "telemetry_enabled" in inspect.signature(getattr(module, method)).parameters and run_config.telemetry:
args.append(run_config.telemetry.enabled)
fn = getattr(module, method)
impl = await fn(*args)
impl.__provider_id__ = provider.provider_id
impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config
protocols = api_protocol_map_for_compliance_check(run_config)
additional_protocols = additional_protocols_map()
# TODO: check compliance for special tool groups
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
check_protocol_compliance(impl, protocols[provider_spec.api])
if not isinstance(provider_spec, AutoRoutedProviderSpec) and provider_spec.api in additional_protocols:
additional_api, _, _ = additional_protocols[provider_spec.api]
check_protocol_compliance(impl, additional_api)
return impl
def check_protocol_compliance(obj: Any, protocol: Any) -> None:
missing_methods = []
mro = type(obj).__mro__
for name, value in inspect.getmembers(protocol):
if inspect.isfunction(value) and hasattr(value, "__webmethods__"):
has_alpha_api = False
for webmethod in value.__webmethods__:
if webmethod.level == LLAMA_STACK_API_V1ALPHA:
has_alpha_api = True
break
# if this API has multiple webmethods, and one of them is an alpha API, this API should be skipped when checking for missing or not callable routes
if has_alpha_api:
continue
if not hasattr(obj, name):
missing_methods.append((name, "missing"))
elif not callable(getattr(obj, name)):
missing_methods.append((name, "not_callable"))
else:
# Check if the method signatures are compatible
obj_method = getattr(obj, name)
proto_sig = inspect.signature(value)
obj_sig = inspect.signature(obj_method)
proto_params = set(proto_sig.parameters)
proto_params.discard("self")
obj_params = set(obj_sig.parameters)
obj_params.discard("self")
if not (proto_params <= obj_params):
logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
missing_methods.append((name, "signature_mismatch"))
else:
# Check if the method has a concrete implementation (not just a protocol stub)
# Find all classes in MRO that define this method
method_owners = [cls for cls in mro if name in cls.__dict__]
# Allow methods from mixins/parents, only reject if ONLY the protocol defines it
if len(method_owners) == 1 and method_owners[0].__name__ == protocol.__name__:
# Only reject if the method is ONLY defined in the protocol itself (abstract stub)
missing_methods.append((name, "not_actually_implemented"))
if missing_methods:
raise ValueError(
f"Provider `{obj.__provider_id__} ({obj.__provider_spec__.api})` does not implement the following methods:\n{missing_methods}"
)
async def resolve_remote_stack_impls(
config: RemoteProviderConfig,
apis: list[str],
) -> dict[Api, Any]:
protocols = api_protocol_map()
additional_protocols = additional_protocols_map()
impls = {}
for api_str in apis:
api = Api(api_str)
impls[api] = await get_client_impl(
protocols[api],
config,
{},
)
if api in additional_protocols:
_, additional_protocol, additional_api = additional_protocols[api]
impls[additional_api] = await get_client_impl(
additional_protocol,
config,
{},
)
return impls

View file

@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.core.datatypes import (
AccessRule,
RoutedProtocol,
)
from llama_stack.core.stack import StackRunConfig
from llama_stack.core.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable
from llama_stack.providers.utils.inference.inference_store import InferenceStore
async def get_routing_table_impl(
api: Api,
impls_by_provider_id: dict[str, RoutedProtocol],
_deps,
dist_registry: DistributionRegistry,
policy: list[AccessRule],
) -> Any:
from ..routing_tables.benchmarks import BenchmarksRoutingTable
from ..routing_tables.datasets import DatasetsRoutingTable
from ..routing_tables.models import ModelsRoutingTable
from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable
from ..routing_tables.shields import ShieldsRoutingTable
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
from ..routing_tables.vector_stores import VectorStoresRoutingTable
api_to_tables = {
"models": ModelsRoutingTable,
"shields": ShieldsRoutingTable,
"datasets": DatasetsRoutingTable,
"scoring_functions": ScoringFunctionsRoutingTable,
"benchmarks": BenchmarksRoutingTable,
"tool_groups": ToolGroupsRoutingTable,
"vector_stores": VectorStoresRoutingTable,
}
if api.value not in api_to_tables:
raise ValueError(f"API {api.value} not found in router map")
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry, policy)
await impl.initialize()
return impl
async def get_auto_router_impl(
api: Api, routing_table: RoutingTable, deps: dict[str, Any], run_config: StackRunConfig, policy: list[AccessRule]
) -> Any:
from .datasets import DatasetIORouter
from .eval_scoring import EvalRouter, ScoringRouter
from .inference import InferenceRouter
from .safety import SafetyRouter
from .tool_runtime import ToolRuntimeRouter
from .vector_io import VectorIORouter
api_to_routers = {
"vector_io": VectorIORouter,
"inference": InferenceRouter,
"safety": SafetyRouter,
"datasetio": DatasetIORouter,
"scoring": ScoringRouter,
"eval": EvalRouter,
"tool_runtime": ToolRuntimeRouter,
}
if api.value not in api_to_routers:
raise ValueError(f"API {api.value} not found in router map")
api_to_dep_impl = {}
# TODO: move pass configs to routers instead
if api == Api.inference:
inference_ref = run_config.storage.stores.inference
if not inference_ref:
raise ValueError("storage.stores.inference must be configured in run config")
inference_store = InferenceStore(
reference=inference_ref,
policy=policy,
)
await inference_store.initialize()
api_to_dep_impl["store"] = inference_store
api_to_dep_impl["telemetry_enabled"] = run_config.telemetry.enabled
elif api == Api.vector_io:
api_to_dep_impl["vector_stores_config"] = run_config.vector_stores
elif api == Api.safety:
api_to_dep_impl["safety_config"] = run_config.safety
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
await impl.initialize()
return impl

View file

@ -0,0 +1,73 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core::routers")
class DatasetIORouter(DatasetIO):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing DatasetIORouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("DatasetIORouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("DatasetIORouter.shutdown")
pass
async def register_dataset(
self,
purpose: DatasetPurpose,
source: DataSource,
metadata: dict[str, Any] | None = None,
dataset_id: str | None = None,
) -> None:
logger.debug(
f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
)
await self.routing_table.register_dataset(
purpose=purpose,
source=source,
metadata=metadata,
dataset_id=dataset_id,
)
async def iterrows(
self,
dataset_id: str,
start_index: int | None = None,
limit: int | None = None,
) -> PaginatedResponse:
logger.debug(
f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
)
provider = await self.routing_table.get_provider_impl(dataset_id)
return await provider.iterrows(
dataset_id=dataset_id,
start_index=start_index,
limit=limit,
)
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
provider = await self.routing_table.get_provider_impl(dataset_id)
return await provider.append_rows(
dataset_id=dataset_id,
rows=rows,
)

View file

@ -0,0 +1,155 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
from llama_stack.apis.scoring import (
ScoreBatchResponse,
ScoreResponse,
Scoring,
ScoringFnParams,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core::routers")
class ScoringRouter(Scoring):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing ScoringRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("ScoringRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("ScoringRouter.shutdown")
pass
async def score_batch(
self,
dataset_id: str,
scoring_functions: dict[str, ScoringFnParams | None] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
res = {}
for fn_identifier in scoring_functions.keys():
provider = await self.routing_table.get_provider_impl(fn_identifier)
score_response = await provider.score_batch(
dataset_id=dataset_id,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
res.update(score_response.results)
if save_results_dataset:
raise NotImplementedError("Save results dataset not implemented yet")
return ScoreBatchResponse(
results=res,
)
async def score(
self,
input_rows: list[dict[str, Any]],
scoring_functions: dict[str, ScoringFnParams | None] = None,
) -> ScoreResponse:
logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
res = {}
# look up and map each scoring function to its provider impl
for fn_identifier in scoring_functions.keys():
provider = await self.routing_table.get_provider_impl(fn_identifier)
score_response = await provider.score(
input_rows=input_rows,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
res.update(score_response.results)
return ScoreResponse(results=res)
class EvalRouter(Eval):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing EvalRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("EvalRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("EvalRouter.shutdown")
pass
async def run_eval(
self,
benchmark_id: str,
benchmark_config: BenchmarkConfig,
) -> Job:
logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
provider = await self.routing_table.get_provider_impl(benchmark_id)
return await provider.run_eval(
benchmark_id=benchmark_id,
benchmark_config=benchmark_config,
)
async def evaluate_rows(
self,
benchmark_id: str,
input_rows: list[dict[str, Any]],
scoring_functions: list[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
provider = await self.routing_table.get_provider_impl(benchmark_id)
return await provider.evaluate_rows(
benchmark_id=benchmark_id,
input_rows=input_rows,
scoring_functions=scoring_functions,
benchmark_config=benchmark_config,
)
async def job_status(
self,
benchmark_id: str,
job_id: str,
) -> Job:
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
provider = await self.routing_table.get_provider_impl(benchmark_id)
return await provider.job_status(benchmark_id, job_id)
async def job_cancel(
self,
benchmark_id: str,
job_id: str,
) -> None:
logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
provider = await self.routing_table.get_provider_impl(benchmark_id)
await provider.job_cancel(
benchmark_id,
job_id,
)
async def job_result(
self,
benchmark_id: str,
job_id: str,
) -> EvaluateResponse:
logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
provider = await self.routing_table.get_provider_impl(benchmark_id)
return await provider.job_result(
benchmark_id,
job_id,
)

View file

@ -0,0 +1,608 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import time
from collections.abc import AsyncGenerator, AsyncIterator
from datetime import UTC, datetime
from typing import Annotated, Any
from fastapi import Body
from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
from pydantic import TypeAdapter
from llama_stack.apis.common.content_types import (
InterleavedContent,
)
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
CompletionResponse,
CompletionResponseStreamChunk,
Inference,
ListOpenAIChatCompletionResponse,
Message,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIChoice,
OpenAIChoiceLogprobs,
OpenAICompletion,
OpenAICompletionRequestWithExtraBody,
OpenAICompletionWithInputMessages,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
Order,
RerankResponse,
StopReason,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse
from llama_stack.core.telemetry.tracing import enqueue_event, get_current_span
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
from llama_stack.providers.utils.inference.inference_store import InferenceStore
logger = get_logger(name=__name__, category="core::routers")
class InferenceRouter(Inference):
"""Routes to an provider based on the model"""
def __init__(
self,
routing_table: RoutingTable,
store: InferenceStore | None = None,
telemetry_enabled: bool = False,
) -> None:
logger.debug("Initializing InferenceRouter")
self.routing_table = routing_table
self.telemetry_enabled = telemetry_enabled
self.store = store
if self.telemetry_enabled:
self.tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(self.tokenizer)
async def initialize(self) -> None:
logger.debug("InferenceRouter.initialize")
async def shutdown(self) -> None:
logger.debug("InferenceRouter.shutdown")
if self.store:
try:
await self.store.shutdown()
except Exception as e:
logger.warning(f"Error during InferenceStore shutdown: {e}")
async def register_model(
self,
model_id: str,
provider_model_id: str | None = None,
provider_id: str | None = None,
metadata: dict[str, Any] | None = None,
model_type: ModelType | None = None,
) -> None:
logger.debug(
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
)
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
def _construct_metrics(
self,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
model: Model,
) -> list[MetricEvent]:
"""Constructs a list of MetricEvent objects containing token usage metrics.
Args:
prompt_tokens: Number of tokens in the prompt
completion_tokens: Number of tokens in the completion
total_tokens: Total number of tokens used
model: Model object containing model_id and provider_id
Returns:
List of MetricEvent objects with token usage metrics
"""
span = get_current_span()
if span is None:
logger.warning("No span found for token usage metrics")
return []
metrics = [
("prompt_tokens", prompt_tokens),
("completion_tokens", completion_tokens),
("total_tokens", total_tokens),
]
metric_events = []
for metric_name, value in metrics:
metric_events.append(
MetricEvent(
trace_id=span.trace_id,
span_id=span.span_id,
metric=metric_name,
value=value,
timestamp=datetime.now(UTC),
unit="tokens",
attributes={
"model_id": model.model_id,
"provider_id": model.provider_id,
},
)
)
return metric_events
async def _compute_and_log_token_usage(
self,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
model: Model,
) -> list[MetricInResponse]:
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
if self.telemetry_enabled:
for metric in metrics:
enqueue_event(metric)
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
async def _count_tokens(
self,
messages: list[Message] | InterleavedContent,
tool_prompt_format: ToolPromptFormat | None = None,
) -> int | None:
if not hasattr(self, "formatter") or self.formatter is None:
return None
if isinstance(messages, list):
encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
else:
encoded = self.formatter.encode_content(messages)
return len(encoded.tokens) if encoded and encoded.tokens else 0
async def _get_model(self, model_id: str, expected_model_type: str) -> Model:
"""takes a model id and gets model after ensuring that it is accessible and of the correct type"""
model = await self.routing_table.get_model(model_id)
if model is None:
raise ModelNotFoundError(model_id)
if model.model_type != expected_model_type:
raise ModelTypeError(model_id, model.model_type, expected_model_type)
return model
async def rerank(
self,
model: str,
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
max_num_results: int | None = None,
) -> RerankResponse:
logger.debug(f"InferenceRouter.rerank: {model}")
model_obj = await self._get_model(model, ModelType.rerank)
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
return await provider.rerank(
model=model_obj.identifier,
query=query,
items=items,
max_num_results=max_num_results,
)
async def openai_completion(
self,
params: Annotated[OpenAICompletionRequestWithExtraBody, Body(...)],
) -> OpenAICompletion:
logger.debug(
f"InferenceRouter.openai_completion: model={params.model}, stream={params.stream}, prompt={params.prompt}",
)
model_obj = await self._get_model(params.model, ModelType.llm)
# Update params with the resolved model identifier
params.model = model_obj.identifier
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
if params.stream:
return await provider.openai_completion(params)
# TODO: Metrics do NOT work with openai_completion stream=True due to the fact
# that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently.
response = await provider.openai_completion(params)
if self.telemetry_enabled:
metrics = self._construct_metrics(
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
model=model_obj,
)
for metric in metrics:
enqueue_event(metric)
# these metrics will show up in the client response.
response.metrics = (
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
)
return response
async def openai_chat_completion(
self,
params: Annotated[OpenAIChatCompletionRequestWithExtraBody, Body(...)],
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
logger.debug(
f"InferenceRouter.openai_chat_completion: model={params.model}, stream={params.stream}, messages={params.messages}",
)
model_obj = await self._get_model(params.model, ModelType.llm)
# Use the OpenAI client for a bit of extra input validation without
# exposing the OpenAI client itself as part of our API surface
if params.tool_choice:
TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(params.tool_choice)
if params.tools is None:
raise ValueError("'tool_choice' is only allowed when 'tools' is also provided")
if params.tools:
for tool in params.tools:
TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)
# Some providers make tool calls even when tool_choice is "none"
# so just clear them both out to avoid unexpected tool calls
if params.tool_choice == "none" and params.tools is not None:
params.tool_choice = None
params.tools = None
# Update params with the resolved model identifier
params.model = model_obj.identifier
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
if params.stream:
response_stream = await provider.openai_chat_completion(params)
# For streaming, the provider returns AsyncIterator[OpenAIChatCompletionChunk]
# We need to add metrics to each chunk and store the final completion
return self.stream_tokens_and_compute_metrics_openai_chat(
response=response_stream,
model=model_obj,
messages=params.messages,
)
response = await self._nonstream_openai_chat_completion(provider, params)
# Store the response with the ID that will be returned to the client
if self.store:
asyncio.create_task(self.store.store_chat_completion(response, params.messages))
if self.telemetry_enabled:
metrics = self._construct_metrics(
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
model=model_obj,
)
for metric in metrics:
enqueue_event(metric)
# these metrics will show up in the client response.
response.metrics = (
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
)
return response
async def openai_embeddings(
self,
params: Annotated[OpenAIEmbeddingsRequestWithExtraBody, Body(...)],
) -> OpenAIEmbeddingsResponse:
logger.debug(
f"InferenceRouter.openai_embeddings: model={params.model}, input_type={type(params.input)}, encoding_format={params.encoding_format}, dimensions={params.dimensions}",
)
model_obj = await self._get_model(params.model, ModelType.embedding)
# Update model to use resolved identifier
params.model = model_obj.identifier
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
return await provider.openai_embeddings(params)
async def list_chat_completions(
self,
after: str | None = None,
limit: int | None = 20,
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIChatCompletionResponse:
if self.store:
return await self.store.list_chat_completions(after, limit, model, order)
raise NotImplementedError("List chat completions is not supported: inference store is not configured.")
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
if self.store:
return await self.store.get_chat_completion(completion_id)
raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")
async def _nonstream_openai_chat_completion(
self, provider: Inference, params: OpenAIChatCompletionRequestWithExtraBody
) -> OpenAIChatCompletion:
response = await provider.openai_chat_completion(params)
for choice in response.choices:
# some providers return an empty list for no tool calls in non-streaming responses
# but the OpenAI API returns None. So, set tool_calls to None if it's empty
if choice.message and choice.message.tool_calls is not None and len(choice.message.tool_calls) == 0:
choice.message.tool_calls = None
return response
async def health(self) -> dict[str, HealthResponse]:
health_statuses = {}
timeout = 1 # increasing the timeout to 1 second for health checks
for provider_id, impl in self.routing_table.impls_by_provider_id.items():
try:
# check if the provider has a health method
if not hasattr(impl, "health"):
continue
health = await asyncio.wait_for(impl.health(), timeout=timeout)
health_statuses[provider_id] = health
except TimeoutError:
health_statuses[provider_id] = HealthResponse(
status=HealthStatus.ERROR,
message=f"Health check timed out after {timeout} seconds",
)
except NotImplementedError:
health_statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
except Exception as e:
health_statuses[provider_id] = HealthResponse(
status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
)
return health_statuses
async def stream_tokens_and_compute_metrics(
self,
response,
prompt_tokens,
model,
tool_prompt_format: ToolPromptFormat | None = None,
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None] | AsyncGenerator[CompletionResponseStreamChunk, None]:
completion_text = ""
async for chunk in response:
complete = False
if hasattr(chunk, "event"): # only ChatCompletions have .event
if chunk.event.event_type == ChatCompletionResponseEventType.progress:
if chunk.event.delta.type == "text":
completion_text += chunk.event.delta.text
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
complete = True
completion_tokens = await self._count_tokens(
[
CompletionMessage(
content=completion_text,
stop_reason=StopReason.end_of_turn,
)
],
tool_prompt_format=tool_prompt_format,
)
else:
if hasattr(chunk, "delta"):
completion_text += chunk.delta
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry_enabled:
complete = True
completion_tokens = await self._count_tokens(completion_text)
# if we are done receiving tokens
if complete:
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
# Create a separate span for streaming completion metrics
if self.telemetry_enabled:
# Log metrics in the new span context
completion_metrics = self._construct_metrics(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
model=model,
)
for metric in completion_metrics:
if metric.metric in [
"completion_tokens",
"total_tokens",
]: # Only log completion and total tokens
enqueue_event(metric)
# Return metrics in response
async_metrics = [
MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
]
chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
else:
# Fallback if no telemetry
completion_metrics = self._construct_metrics(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
async_metrics = [
MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
]
chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
yield chunk
async def count_tokens_and_compute_metrics(
self,
response: ChatCompletionResponse | CompletionResponse,
prompt_tokens,
model,
tool_prompt_format: ToolPromptFormat | None = None,
):
if isinstance(response, ChatCompletionResponse):
content = [response.completion_message]
else:
content = response.content
completion_tokens = await self._count_tokens(messages=content, tool_prompt_format=tool_prompt_format)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
# Create a separate span for completion metrics
if self.telemetry_enabled:
# Log metrics in the new span context
completion_metrics = self._construct_metrics(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
model=model,
)
for metric in completion_metrics:
if metric.metric in ["completion_tokens", "total_tokens"]: # Only log completion and total tokens
enqueue_event(metric)
# Return metrics in response
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics]
# Fallback if no telemetry
metrics = self._construct_metrics(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
async def stream_tokens_and_compute_metrics_openai_chat(
self,
response: AsyncIterator[OpenAIChatCompletionChunk],
model: Model,
messages: list[OpenAIMessageParam] | None = None,
) -> AsyncIterator[OpenAIChatCompletionChunk]:
"""Stream OpenAI chat completion chunks, compute metrics, and store the final completion."""
id = None
created = None
choices_data: dict[int, dict[str, Any]] = {}
try:
async for chunk in response:
# Skip None chunks
if chunk is None:
continue
# Capture ID and created timestamp from first chunk
if id is None and chunk.id:
id = chunk.id
if created is None and chunk.created:
created = chunk.created
# Accumulate choice data for final assembly
if chunk.choices:
for choice_delta in chunk.choices:
idx = choice_delta.index
if idx not in choices_data:
choices_data[idx] = {
"content_parts": [],
"tool_calls_builder": {},
"finish_reason": "stop",
"logprobs_content_parts": [],
}
current_choice_data = choices_data[idx]
if choice_delta.delta:
delta = choice_delta.delta
if delta.content:
current_choice_data["content_parts"].append(delta.content)
if delta.tool_calls:
for tool_call_delta in delta.tool_calls:
tc_idx = tool_call_delta.index
if tc_idx not in current_choice_data["tool_calls_builder"]:
current_choice_data["tool_calls_builder"][tc_idx] = {
"id": None,
"type": "function",
"function_name_parts": [],
"function_arguments_parts": [],
}
builder = current_choice_data["tool_calls_builder"][tc_idx]
if tool_call_delta.id:
builder["id"] = tool_call_delta.id
if tool_call_delta.type:
builder["type"] = tool_call_delta.type
if tool_call_delta.function:
if tool_call_delta.function.name:
builder["function_name_parts"].append(tool_call_delta.function.name)
if tool_call_delta.function.arguments:
builder["function_arguments_parts"].append(
tool_call_delta.function.arguments
)
if choice_delta.finish_reason:
current_choice_data["finish_reason"] = choice_delta.finish_reason
if choice_delta.logprobs and choice_delta.logprobs.content:
current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content)
# Compute metrics on final chunk
if chunk.choices and chunk.choices[0].finish_reason:
completion_text = ""
for choice_data in choices_data.values():
completion_text += "".join(choice_data["content_parts"])
# Add metrics to the chunk
if self.telemetry_enabled and hasattr(chunk, "usage") and chunk.usage:
metrics = self._construct_metrics(
prompt_tokens=chunk.usage.prompt_tokens,
completion_tokens=chunk.usage.completion_tokens,
total_tokens=chunk.usage.total_tokens,
model=model,
)
for metric in metrics:
enqueue_event(metric)
yield chunk
finally:
# Store the final assembled completion
if id and self.store and messages:
assembled_choices: list[OpenAIChoice] = []
for choice_idx, choice_data in choices_data.items():
content_str = "".join(choice_data["content_parts"])
assembled_tool_calls: list[OpenAIChatCompletionToolCall] = []
if choice_data["tool_calls_builder"]:
for tc_build_data in choice_data["tool_calls_builder"].values():
if tc_build_data["id"]:
func_name = "".join(tc_build_data["function_name_parts"])
func_args = "".join(tc_build_data["function_arguments_parts"])
assembled_tool_calls.append(
OpenAIChatCompletionToolCall(
id=tc_build_data["id"],
type=tc_build_data["type"],
function=OpenAIChatCompletionToolCallFunction(
name=func_name, arguments=func_args
),
)
)
message = OpenAIAssistantMessageParam(
role="assistant",
content=content_str if content_str else None,
tool_calls=assembled_tool_calls if assembled_tool_calls else None,
)
logprobs_content = choice_data["logprobs_content_parts"]
final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None
assembled_choices.append(
OpenAIChoice(
finish_reason=choice_data["finish_reason"],
index=choice_idx,
message=message,
logprobs=final_logprobs,
)
)
final_response = OpenAIChatCompletion(
id=id,
choices=assembled_choices,
created=created or int(time.time()),
model=model.identifier,
object="chat.completion",
)
logger.debug(f"InferenceRouter.completion_response: {final_response}")
asyncio.create_task(self.store.store_chat_completion(final_response, messages))

View file

@ -0,0 +1,109 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.safety.safety import ModerationObject
from llama_stack.apis.shields import Shield
from llama_stack.core.datatypes import SafetyConfig
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core::routers")
class SafetyRouter(Safety):
def __init__(
self,
routing_table: RoutingTable,
safety_config: SafetyConfig | None = None,
) -> None:
logger.debug("Initializing SafetyRouter")
self.routing_table = routing_table
self.safety_config = safety_config
async def initialize(self) -> None:
logger.debug("SafetyRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("SafetyRouter.shutdown")
pass
async def register_shield(
self,
shield_id: str,
provider_shield_id: str | None = None,
provider_id: str | None = None,
params: dict[str, Any] | None = None,
) -> Shield:
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
async def unregister_shield(self, identifier: str) -> None:
logger.debug(f"SafetyRouter.unregister_shield: {identifier}")
return await self.routing_table.unregister_shield(identifier)
async def run_shield(
self,
shield_id: str,
messages: list[Message],
params: dict[str, Any] = None,
) -> RunShieldResponse:
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
provider = await self.routing_table.get_provider_impl(shield_id)
return await provider.run_shield(
shield_id=shield_id,
messages=messages,
params=params,
)
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
list_shields_response = await self.routing_table.list_shields()
shields = list_shields_response.data
selected_shield: Shield | None = None
provider_model: str | None = model
if model:
matches: list[Shield] = [s for s in shields if model == s.provider_resource_id]
if not matches:
raise ValueError(
f"No shield associated with provider_resource id {model}: choose from {[s.provider_resource_id for s in shields]}"
)
if len(matches) > 1:
raise ValueError(
f"Multiple shields associated with provider_resource id {model}: matched shields {[s.identifier for s in matches]}"
)
selected_shield = matches[0]
else:
default_shield_id = self.safety_config.default_shield_id if self.safety_config else None
if not default_shield_id:
raise ValueError(
"No moderation model specified and no default_shield_id configured in safety config: select model "
f"from {[s.provider_resource_id or s.identifier for s in shields]}"
)
selected_shield = next((s for s in shields if s.identifier == default_shield_id), None)
if selected_shield is None:
raise ValueError(
f"Default moderation model not found. Choose from {[s.provider_resource_id or s.identifier for s in shields]}."
)
provider_model = selected_shield.provider_resource_id
shield_id = selected_shield.identifier
logger.debug(f"SafetyRouter.run_moderation: {shield_id}")
provider = await self.routing_table.get_provider_impl(shield_id)
response = await provider.run_moderation(
input=input,
model=provider_model,
)
return response

View file

@ -0,0 +1,91 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
)
from llama_stack.apis.tools import (
ListToolDefsResponse,
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
RAGToolRuntime,
ToolRuntime,
)
from llama_stack.log import get_logger
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
logger = get_logger(name=__name__, category="core::routers")
class ToolRuntimeRouter(ToolRuntime):
class RagToolImpl(RAGToolRuntime):
def __init__(
self,
routing_table: ToolGroupsRoutingTable,
) -> None:
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
self.routing_table = routing_table
async def query(
self,
content: InterleavedContent,
vector_store_ids: list[str],
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_store_ids}")
provider = await self.routing_table.get_provider_impl("knowledge_search")
return await provider.query(content, vector_store_ids, query_config)
async def insert(
self,
documents: list[RAGDocument],
vector_store_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
logger.debug(
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_store_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
)
provider = await self.routing_table.get_provider_impl("insert_into_memory")
return await provider.insert(documents, vector_store_id, chunk_size_in_tokens)
def __init__(
self,
routing_table: ToolGroupsRoutingTable,
) -> None:
logger.debug("Initializing ToolRuntimeRouter")
self.routing_table = routing_table
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
self.rag_tool = self.RagToolImpl(routing_table)
for method in ("query", "insert"):
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
async def initialize(self) -> None:
logger.debug("ToolRuntimeRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("ToolRuntimeRouter.shutdown")
pass
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
provider = await self.routing_table.get_provider_impl(tool_name)
return await provider.invoke_tool(
tool_name=tool_name,
kwargs=kwargs,
)
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse:
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.list_tools(tool_group_id)

View file

@ -0,0 +1,442 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import uuid
from typing import Annotated, Any
from fastapi import Body
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.models import ModelType
from llama_stack.apis.vector_io import (
Chunk,
OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
OpenAICreateVectorStoreRequestWithExtraBody,
QueryChunksResponse,
SearchRankingOptions,
VectorIO,
VectorStoreChunkingStrategy,
VectorStoreDeleteResponse,
VectorStoreFileBatchObject,
VectorStoreFileContentsResponse,
VectorStoreFileDeleteResponse,
VectorStoreFileObject,
VectorStoreFilesListInBatchResponse,
VectorStoreFileStatus,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
logger = get_logger(name=__name__, category="core::routers")
class VectorIORouter(VectorIO):
"""Routes to an provider based on the vector db identifier"""
def __init__(
self,
routing_table: RoutingTable,
vector_stores_config: VectorStoresConfig | None = None,
) -> None:
logger.debug("Initializing VectorIORouter")
self.routing_table = routing_table
self.vector_stores_config = vector_stores_config
async def initialize(self) -> None:
logger.debug("VectorIORouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("VectorIORouter.shutdown")
pass
async def _get_embedding_model_dimension(self, embedding_model_id: str) -> int:
"""Get the embedding dimension for a specific embedding model."""
all_models = await self.routing_table.get_all_with_type("model")
for model in all_models:
if model.identifier == embedding_model_id and model.model_type == ModelType.embedding:
dimension = model.metadata.get("embedding_dimension")
if dimension is None:
raise ValueError(f"Embedding model '{embedding_model_id}' has no embedding_dimension in metadata")
return int(dimension)
raise ValueError(f"Embedding model '{embedding_model_id}' not found or not an embedding model")
async def insert_chunks(
self,
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
doc_ids = [chunk.document_id for chunk in chunks[:3]]
logger.debug(
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, "
f"ttl_seconds={ttl_seconds}, chunk_ids={doc_ids}{' and more...' if len(chunks) > 3 else ''}"
)
provider = await self.routing_table.get_provider_impl(vector_db_id)
return await provider.insert_chunks(vector_db_id, chunks, ttl_seconds)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
provider = await self.routing_table.get_provider_impl(vector_db_id)
return await provider.query_chunks(vector_db_id, query, params)
# OpenAI Vector Stores API endpoints
async def openai_create_vector_store(
self,
params: Annotated[OpenAICreateVectorStoreRequestWithExtraBody, Body(...)],
) -> VectorStoreObject:
# Extract llama-stack-specific parameters from extra_body
extra = params.model_extra or {}
embedding_model = extra.get("embedding_model")
embedding_dimension = extra.get("embedding_dimension")
provider_id = extra.get("provider_id")
# Use default embedding model if not specified
if (
embedding_model is None
and self.vector_stores_config
and self.vector_stores_config.default_embedding_model is not None
):
# Construct the full model ID with provider prefix
embedding_provider_id = self.vector_stores_config.default_embedding_model.provider_id
model_id = self.vector_stores_config.default_embedding_model.model_id
embedding_model = f"{embedding_provider_id}/{model_id}"
if embedding_model is not None and embedding_dimension is None:
embedding_dimension = await self._get_embedding_model_dimension(embedding_model)
# Auto-select provider if not specified
if provider_id is None:
num_providers = len(self.routing_table.impls_by_provider_id)
if num_providers == 0:
raise ValueError("No vector_io providers available")
if num_providers > 1:
available_providers = list(self.routing_table.impls_by_provider_id.keys())
# Use default configured provider
if self.vector_stores_config and self.vector_stores_config.default_provider_id:
default_provider = self.vector_stores_config.default_provider_id
if default_provider in available_providers:
provider_id = default_provider
logger.debug(f"Using configured default vector store provider: {provider_id}")
else:
raise ValueError(
f"Configured default vector store provider '{default_provider}' not found. "
f"Available providers: {available_providers}"
)
else:
raise ValueError(
f"Multiple vector_io providers available. Please specify provider_id in extra_body. "
f"Available providers: {available_providers}"
)
else:
provider_id = list(self.routing_table.impls_by_provider_id.keys())[0]
vector_store_id = f"vs_{uuid.uuid4()}"
registered_vector_store = await self.routing_table.register_vector_store(
vector_store_id=vector_store_id,
embedding_model=embedding_model,
embedding_dimension=embedding_dimension,
provider_id=provider_id,
provider_vector_store_id=vector_store_id,
vector_store_name=params.name,
)
provider = await self.routing_table.get_provider_impl(registered_vector_store.identifier)
# Update model_extra with registered values so provider uses the already-registered vector_store
if params.model_extra is None:
params.model_extra = {}
params.model_extra["provider_vector_store_id"] = registered_vector_store.provider_resource_id
params.model_extra["provider_id"] = registered_vector_store.provider_id
if embedding_model is not None:
params.model_extra["embedding_model"] = embedding_model
if embedding_dimension is not None:
params.model_extra["embedding_dimension"] = embedding_dimension
return await provider.openai_create_vector_store(params)
async def openai_list_vector_stores(
self,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
logger.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}")
# Route to default provider for now - could aggregate from all providers in the future
# call retrieve on each vector dbs to get list of vector stores
vector_stores = await self.routing_table.get_all_with_type("vector_store")
all_stores = []
for vector_store in vector_stores:
try:
provider = await self.routing_table.get_provider_impl(vector_store.identifier)
vector_store = await provider.openai_retrieve_vector_store(vector_store.identifier)
all_stores.append(vector_store)
except Exception as e:
logger.error(f"Error retrieving vector store {vector_store.identifier}: {e}")
continue
# Sort by created_at
reverse_order = order == "desc"
all_stores.sort(key=lambda x: x.created_at, reverse=reverse_order)
# Apply cursor-based pagination
if after:
after_index = next((i for i, store in enumerate(all_stores) if store.id == after), -1)
if after_index >= 0:
all_stores = all_stores[after_index + 1 :]
if before:
before_index = next(
(i for i, store in enumerate(all_stores) if store.id == before),
len(all_stores),
)
all_stores = all_stores[:before_index]
# Apply limit
limited_stores = all_stores[:limit]
# Determine pagination info
has_more = len(all_stores) > limit
first_id = limited_stores[0].id if limited_stores else None
last_id = limited_stores[-1].id if limited_stores else None
return VectorStoreListResponse(
data=limited_stores,
has_more=has_more,
first_id=first_id,
last_id=last_id,
)
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store(vector_store_id)
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store(
vector_store_id=vector_store_id,
name=name,
expires_after=expires_after,
metadata=metadata,
)
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
return await self.routing_table.openai_delete_vector_store(vector_store_id)
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
search_mode: str | None = "vector",
) -> VectorStoreSearchResponsePage:
logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_search_vector_store(
vector_store_id=vector_store_id,
query=query,
filters=filters,
max_num_results=max_num_results,
ranking_options=ranking_options,
rewrite_query=rewrite_query,
search_mode=search_mode,
)
async def openai_attach_file_to_vector_store(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_attach_file_to_vector_store(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
chunking_strategy=chunking_strategy,
)
async def openai_list_files_in_vector_store(
self,
vector_store_id: str,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
filter: VectorStoreFileStatus | None = None,
) -> list[VectorStoreFileObject]:
logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}")
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_list_files_in_vector_store(
vector_store_id=vector_store_id,
limit=limit,
order=order,
after=after,
before=before,
filter=filter,
)
async def openai_retrieve_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}")
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)
async def openai_retrieve_vector_store_file_contents(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_contents(
vector_store_id=vector_store_id,
file_id=file_id,
)
async def openai_update_vector_store_file(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any],
) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}")
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
)
async def openai_delete_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileDeleteResponse:
logger.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}")
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_delete_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)
async def health(self) -> dict[str, HealthResponse]:
health_statuses = {}
timeout = 1 # increasing the timeout to 1 second for health checks
for provider_id, impl in self.routing_table.impls_by_provider_id.items():
try:
# check if the provider has a health method
if not hasattr(impl, "health"):
continue
health = await asyncio.wait_for(impl.health(), timeout=timeout)
health_statuses[provider_id] = health
except TimeoutError:
health_statuses[provider_id] = HealthResponse(
status=HealthStatus.ERROR,
message=f"Health check timed out after {timeout} seconds",
)
except NotImplementedError:
health_statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
except Exception as e:
health_statuses[provider_id] = HealthResponse(
status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
)
return health_statuses
async def openai_create_vector_store_file_batch(
self,
vector_store_id: str,
params: Annotated[OpenAICreateVectorStoreFileBatchRequestWithExtraBody, Body(...)],
) -> VectorStoreFileBatchObject:
logger.debug(
f"VectorIORouter.openai_create_vector_store_file_batch: {vector_store_id}, {len(params.file_ids)} files"
)
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_create_vector_store_file_batch(vector_store_id, params)
async def openai_retrieve_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
) -> VectorStoreFileBatchObject:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_batch: {batch_id}, {vector_store_id}")
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_batch(
batch_id=batch_id,
vector_store_id=vector_store_id,
)
async def openai_list_files_in_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
after: str | None = None,
before: str | None = None,
filter: str | None = None,
limit: int | None = 20,
order: str | None = "desc",
) -> VectorStoreFilesListInBatchResponse:
logger.debug(f"VectorIORouter.openai_list_files_in_vector_store_file_batch: {batch_id}, {vector_store_id}")
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_list_files_in_vector_store_file_batch(
batch_id=batch_id,
vector_store_id=vector_store_id,
after=after,
before=before,
filter=filter,
limit=limit,
order=order,
)
async def openai_cancel_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
) -> VectorStoreFileBatchObject:
logger.debug(f"VectorIORouter.openai_cancel_vector_store_file_batch: {batch_id}, {vector_store_id}")
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_cancel_vector_store_file_batch(
batch_id=batch_id,
vector_store_id=vector_store_id,
)

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack.core.datatypes import (
BenchmarkWithOwner,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core::routing_tables")
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
async def list_benchmarks(self) -> ListBenchmarksResponse:
return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))
async def get_benchmark(self, benchmark_id: str) -> Benchmark:
benchmark = await self.get_object_by_identifier("benchmark", benchmark_id)
if benchmark is None:
raise ValueError(f"Benchmark '{benchmark_id}' not found")
return benchmark
async def register_benchmark(
self,
benchmark_id: str,
dataset_id: str,
scoring_functions: list[str],
metadata: dict[str, Any] | None = None,
provider_benchmark_id: str | None = None,
provider_id: str | None = None,
) -> None:
if metadata is None:
metadata = {}
if provider_id is None:
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
if provider_benchmark_id is None:
provider_benchmark_id = benchmark_id
benchmark = BenchmarkWithOwner(
identifier=benchmark_id,
dataset_id=dataset_id,
scoring_functions=scoring_functions,
metadata=metadata,
provider_id=provider_id,
provider_resource_id=provider_benchmark_id,
)
await self.register_object(benchmark)
async def unregister_benchmark(self, benchmark_id: str) -> None:
existing_benchmark = await self.get_benchmark(benchmark_id)
await self.unregister_object(existing_benchmark)

View file

@ -0,0 +1,254 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.common.errors import ModelNotFoundError
from llama_stack.apis.models import Model
from llama_stack.apis.resource import ResourceType
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.core.access_control.datatypes import Action
from llama_stack.core.datatypes import (
AccessRule,
RoutableObject,
RoutableObjectWithProvider,
RoutedProtocol,
ScoringFnWithOwner,
)
from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.core.store import DistributionRegistry
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, RoutingTable
logger = get_logger(name=__name__, category="core::routing_tables")
def get_impl_api(p: Any) -> Api:
return p.__provider_spec__.api
# TODO: this should return the registered object for all APIs
async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
api = get_impl_api(p)
assert obj.provider_id != "remote", "Remote provider should not be registered"
if api == Api.inference:
return await p.register_model(obj)
elif api == Api.safety:
return await p.register_shield(obj)
elif api == Api.vector_io:
return await p.register_vector_store(obj)
elif api == Api.datasetio:
return await p.register_dataset(obj)
elif api == Api.scoring:
return await p.register_scoring_function(obj)
elif api == Api.eval:
return await p.register_benchmark(obj)
elif api == Api.tool_runtime:
return await p.register_toolgroup(obj)
else:
raise ValueError(f"Unknown API {api} for registering object with provider")
async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
api = get_impl_api(p)
if api == Api.vector_io:
return await p.unregister_vector_store(obj.identifier)
elif api == Api.inference:
return await p.unregister_model(obj.identifier)
elif api == Api.safety:
return await p.unregister_shield(obj.identifier)
elif api == Api.datasetio:
return await p.unregister_dataset(obj.identifier)
elif api == Api.eval:
return await p.unregister_benchmark(obj.identifier)
elif api == Api.scoring:
return await p.unregister_scoring_function(obj.identifier)
elif api == Api.tool_runtime:
return await p.unregister_toolgroup(obj.identifier)
else:
raise ValueError(f"Unregister not supported for {api}")
Registry = dict[str, list[RoutableObjectWithProvider]]
class CommonRoutingTableImpl(RoutingTable):
def __init__(
self,
impls_by_provider_id: dict[str, RoutedProtocol],
dist_registry: DistributionRegistry,
policy: list[AccessRule],
) -> None:
self.impls_by_provider_id = impls_by_provider_id
self.dist_registry = dist_registry
self.policy = policy
async def initialize(self) -> None:
async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None:
for obj in objs:
if cls is None:
obj.provider_id = provider_id
else:
# Create a copy of the model data and explicitly set provider_id
model_data = obj.model_dump()
model_data["provider_id"] = provider_id
obj = cls(**model_data)
await self.dist_registry.register(obj)
# Register all objects from providers
for pid, p in self.impls_by_provider_id.items():
api = get_impl_api(p)
if api == Api.inference:
p.model_store = self
elif api == Api.safety:
p.shield_store = self
elif api == Api.vector_io:
p.vector_store_store = self
elif api == Api.datasetio:
p.dataset_store = self
elif api == Api.scoring:
p.scoring_function_store = self
scoring_functions = await p.list_scoring_functions()
await add_objects(scoring_functions, pid, ScoringFnWithOwner)
elif api == Api.eval:
p.benchmark_store = self
elif api == Api.tool_runtime:
p.tool_store = self
async def shutdown(self) -> None:
for p in self.impls_by_provider_id.values():
await p.shutdown()
async def refresh(self) -> None:
pass
async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
from .benchmarks import BenchmarksRoutingTable
from .datasets import DatasetsRoutingTable
from .models import ModelsRoutingTable
from .scoring_functions import ScoringFunctionsRoutingTable
from .shields import ShieldsRoutingTable
from .toolgroups import ToolGroupsRoutingTable
from .vector_stores import VectorStoresRoutingTable
def apiname_object():
if isinstance(self, ModelsRoutingTable):
return ("Inference", "model")
elif isinstance(self, ShieldsRoutingTable):
return ("Safety", "shield")
elif isinstance(self, VectorStoresRoutingTable):
return ("VectorIO", "vector_store")
elif isinstance(self, DatasetsRoutingTable):
return ("DatasetIO", "dataset")
elif isinstance(self, ScoringFunctionsRoutingTable):
return ("Scoring", "scoring_function")
elif isinstance(self, BenchmarksRoutingTable):
return ("Eval", "benchmark")
elif isinstance(self, ToolGroupsRoutingTable):
return ("ToolGroups", "tool_group")
else:
raise ValueError("Unknown routing table type")
apiname, objtype = apiname_object()
# Get objects from disk registry
obj = self.dist_registry.get_cached(objtype, routing_key)
if not obj:
provider_ids = list(self.impls_by_provider_id.keys())
if len(provider_ids) > 1:
provider_ids_str = f"any of the providers: {', '.join(provider_ids)}"
else:
provider_ids_str = f"provider: `{provider_ids[0]}`"
raise ValueError(
f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}."
)
if not provider_id or provider_id == obj.provider_id:
return self.impls_by_provider_id[obj.provider_id]
raise ValueError(f"Provider not found for `{routing_key}`")
async def get_object_by_identifier(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
# Get from disk registry
obj = await self.dist_registry.get(type, identifier)
if not obj:
return None
# Check if user has permission to access this object
if not is_action_allowed(self.policy, "read", obj, get_authenticated_user()):
logger.debug(f"Access denied to {type} '{identifier}'")
return None
return obj
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
user = get_authenticated_user()
if not is_action_allowed(self.policy, "delete", obj, user):
raise AccessDeniedError("delete", obj, user)
await self.dist_registry.delete(obj.type, obj.identifier)
await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
# if provider_id is not specified, pick an arbitrary one from existing entries
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
obj.provider_id = list(self.impls_by_provider_id.keys())[0]
if obj.provider_id not in self.impls_by_provider_id:
raise ValueError(f"Provider `{obj.provider_id}` not found")
p = self.impls_by_provider_id[obj.provider_id]
# If object supports access control but no attributes set, use creator's attributes
creator = get_authenticated_user()
if not is_action_allowed(self.policy, "create", obj, creator):
raise AccessDeniedError("create", obj, creator)
if creator:
obj.owner = creator
logger.info(f"Setting owner for {obj.type} '{obj.identifier}' to {obj.owner.principal}")
registered_obj = await register_object_with_provider(obj, p)
# TODO: This needs to be fixed for all APIs once they return the registered object
if obj.type == ResourceType.model.value:
await self.dist_registry.register(registered_obj)
return registered_obj
else:
await self.dist_registry.register(obj)
return obj
async def assert_action_allowed(
self,
action: Action,
type: str,
identifier: str,
) -> None:
"""Fetch a registered object by type/identifier and enforce the given action permission."""
obj = await self.get_object_by_identifier(type, identifier)
if obj is None:
raise ValueError(f"{type.capitalize()} '{identifier}' not found")
user = get_authenticated_user()
if not is_action_allowed(self.policy, action, obj, user):
raise AccessDeniedError(action, obj, user)
async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
objs = await self.dist_registry.get_all()
filtered_objs = [obj for obj in objs if obj.type == type]
# Apply attribute-based access control filtering
if filtered_objs:
filtered_objs = [
obj for obj in filtered_objs if is_action_allowed(self.policy, "read", obj, get_authenticated_user())
]
return filtered_objs
async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> Model:
model = await routing_table.get_object_by_identifier("model", model_id)
if not model:
raise ModelNotFoundError(model_id)
return model

View file

@ -0,0 +1,91 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from typing import Any
from llama_stack.apis.common.errors import DatasetNotFoundError
from llama_stack.apis.datasets import (
Dataset,
DatasetPurpose,
Datasets,
DatasetType,
DataSource,
ListDatasetsResponse,
RowsDataSource,
URIDataSource,
)
from llama_stack.apis.resource import ResourceType
from llama_stack.core.datatypes import (
DatasetWithOwner,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core::routing_tables")
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> ListDatasetsResponse:
return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
async def get_dataset(self, dataset_id: str) -> Dataset:
dataset = await self.get_object_by_identifier("dataset", dataset_id)
if dataset is None:
raise DatasetNotFoundError(dataset_id)
return dataset
async def register_dataset(
self,
purpose: DatasetPurpose,
source: DataSource,
metadata: dict[str, Any] | None = None,
dataset_id: str | None = None,
) -> Dataset:
if isinstance(source, dict):
if source["type"] == "uri":
source = URIDataSource.parse_obj(source)
elif source["type"] == "rows":
source = RowsDataSource.parse_obj(source)
if not dataset_id:
dataset_id = f"dataset-{str(uuid.uuid4())}"
provider_dataset_id = dataset_id
# infer provider from source
if metadata and metadata.get("provider_id"):
provider_id = metadata.get("provider_id") # pass through from nvidia datasetio
elif source.type == DatasetType.rows.value:
provider_id = "localfs"
elif source.type == DatasetType.uri.value:
# infer provider from uri
if source.uri.startswith("huggingface"):
provider_id = "huggingface"
else:
provider_id = "localfs"
else:
raise ValueError(f"Unknown data source type: {source.type}")
if metadata is None:
metadata = {}
dataset = DatasetWithOwner(
identifier=dataset_id,
provider_resource_id=provider_dataset_id,
provider_id=provider_id,
purpose=purpose,
source=source,
metadata=metadata,
)
await self.register_object(dataset)
return dataset
async def unregister_dataset(self, dataset_id: str) -> None:
dataset = await self.get_dataset(dataset_id)
await self.unregister_object(dataset)

View file

@ -0,0 +1,163 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time
from typing import Any
from llama_stack.apis.common.errors import ModelNotFoundError
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.core.datatypes import (
ModelWithOwner,
RegistryEntrySource,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core::routing_tables")
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
listed_providers: set[str] = set()
async def refresh(self) -> None:
for provider_id, provider in self.impls_by_provider_id.items():
refresh = await provider.should_refresh_models()
refresh = refresh or provider_id not in self.listed_providers
if not refresh:
continue
try:
models = await provider.list_models()
except Exception as e:
logger.warning(f"Model refresh failed for provider {provider_id}: {e}")
continue
self.listed_providers.add(provider_id)
if models is None:
continue
await self.update_registered_models(provider_id, models)
async def list_models(self) -> ListModelsResponse:
return ListModelsResponse(data=await self.get_all_with_type("model"))
async def openai_list_models(self) -> OpenAIListModelsResponse:
models = await self.get_all_with_type("model")
openai_models = [
OpenAIModel(
id=model.identifier,
object="model",
created=int(time.time()),
owned_by="llama_stack",
)
for model in models
]
return OpenAIListModelsResponse(data=openai_models)
async def get_model(self, model_id: str) -> Model:
return await lookup_model(self, model_id)
async def get_provider_impl(self, model_id: str) -> Any:
model = await lookup_model(self, model_id)
if model.provider_id not in self.impls_by_provider_id:
raise ValueError(f"Provider {model.provider_id} not found in the routing table")
return self.impls_by_provider_id[model.provider_id]
async def has_model(self, model_id: str) -> bool:
"""
Check if a model exists in the routing table.
:param model_id: The model identifier to check
:return: True if the model exists, False otherwise
"""
try:
await lookup_model(self, model_id)
return True
except ModelNotFoundError:
return False
async def register_model(
self,
model_id: str,
provider_model_id: str | None = None,
provider_id: str | None = None,
metadata: dict[str, Any] | None = None,
model_type: ModelType | None = None,
) -> Model:
if provider_id is None:
# If provider_id not specified, use the only provider if it supports this model
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
f"Please specify a provider_id for model {model_id} since multiple providers are available: {self.impls_by_provider_id.keys()}.\n\n"
"Use the provider_id as a prefix to disambiguate, e.g. 'provider_id/model_id'."
)
provider_model_id = provider_model_id or model_id
metadata = metadata or {}
model_type = model_type or ModelType.llm
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
raise ValueError("Embedding model must have an embedding dimension in its metadata")
identifier = f"{provider_id}/{provider_model_id}"
model = ModelWithOwner(
identifier=identifier,
provider_resource_id=provider_model_id,
provider_id=provider_id,
metadata=metadata,
model_type=model_type,
source=RegistryEntrySource.via_register_api,
)
registered_model = await self.register_object(model)
return registered_model
async def unregister_model(self, model_id: str) -> None:
existing_model = await self.get_model(model_id)
if existing_model is None:
raise ModelNotFoundError(model_id)
await self.unregister_object(existing_model)
async def update_registered_models(
self,
provider_id: str,
models: list[Model],
) -> None:
existing_models = await self.get_all_with_type("model")
# we may have an alias for the model registered by the user (or during initialization
# from run.yaml) that we need to keep track of
model_ids = {}
for model in existing_models:
if model.provider_id != provider_id:
continue
if model.source == RegistryEntrySource.via_register_api:
model_ids[model.provider_resource_id] = model.identifier
continue
logger.debug(f"unregistering model {model.identifier}")
await self.unregister_object(model)
for model in models:
if model.provider_resource_id in model_ids:
# avoid overwriting a non-provider-registered model entry
continue
if model.identifier == model.provider_resource_id:
model.identifier = f"{provider_id}/{model.provider_resource_id}"
logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
await self.register_object(
ModelWithOwner(
identifier=model.identifier,
provider_resource_id=model.provider_resource_id,
provider_id=provider_id,
metadata=model.metadata,
model_type=model.model_type,
source=RegistryEntrySource.listed_from_provider,
)
)

View file

@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
ListScoringFunctionsResponse,
ScoringFn,
ScoringFnParams,
ScoringFunctions,
)
from llama_stack.core.datatypes import (
ScoringFnWithOwner,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core::routing_tables")
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
if scoring_fn is None:
raise ValueError(f"Scoring function '{scoring_fn_id}' not found")
return scoring_fn
async def register_scoring_function(
self,
scoring_fn_id: str,
description: str,
return_type: ParamType,
provider_scoring_fn_id: str | None = None,
provider_id: str | None = None,
params: ScoringFnParams | None = None,
) -> None:
if provider_scoring_fn_id is None:
provider_scoring_fn_id = scoring_fn_id
if provider_id is None:
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
scoring_fn = ScoringFnWithOwner(
identifier=scoring_fn_id,
description=description,
return_type=return_type,
provider_resource_id=provider_scoring_fn_id,
provider_id=provider_id,
params=params,
)
scoring_fn.provider_id = provider_id
await self.register_object(scoring_fn)
async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
existing_scoring_fn = await self.get_scoring_function(scoring_fn_id)
await self.unregister_object(existing_scoring_fn)

View file

@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
from llama_stack.core.datatypes import (
ShieldWithOwner,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core::routing_tables")
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> ListShieldsResponse:
return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
async def get_shield(self, identifier: str) -> Shield:
shield = await self.get_object_by_identifier("shield", identifier)
if shield is None:
raise ValueError(f"Shield '{identifier}' not found")
return shield
async def register_shield(
self,
shield_id: str,
provider_shield_id: str | None = None,
provider_id: str | None = None,
params: dict[str, Any] | None = None,
) -> Shield:
if provider_shield_id is None:
provider_shield_id = shield_id
if provider_id is None:
# If provider_id not specified, use the only provider if it supports this shield type
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
if params is None:
params = {}
shield = ShieldWithOwner(
identifier=shield_id,
provider_resource_id=provider_shield_id,
provider_id=provider_id,
params=params,
)
await self.register_object(shield)
return shield
async def unregister_shield(self, identifier: str) -> None:
existing_shield = await self.get_shield(identifier)
await self.unregister_object(existing_shield)

View file

@ -0,0 +1,129 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.errors import ToolGroupNotFoundError
from llama_stack.apis.tools import ListToolDefsResponse, ListToolGroupsResponse, ToolDef, ToolGroup, ToolGroups
from llama_stack.core.datatypes import AuthenticationRequiredError, ToolGroupWithOwner
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core::routing_tables")
def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None:
# handle the funny case like "builtin::rag/knowledge_search"
parts = toolgroup_name_with_maybe_tool_name.split("/")
if len(parts) == 2:
return parts[0]
else:
return None
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
toolgroups_to_tools: dict[str, list[ToolDef]] = {}
tool_to_toolgroup: dict[str, str] = {}
# overridden
async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
# we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id
# TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while?
toolgroup_id = parse_toolgroup_from_toolgroup_name_pair(routing_key)
if toolgroup_id:
routing_key = toolgroup_id
if routing_key in self.tool_to_toolgroup:
routing_key = self.tool_to_toolgroup[routing_key]
return await super().get_provider_impl(routing_key, provider_id)
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolDefsResponse:
if toolgroup_id:
if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id):
toolgroup_id = group_id
toolgroups = [await self.get_tool_group(toolgroup_id)]
else:
toolgroups = await self.get_all_with_type("tool_group")
all_tools = []
for toolgroup in toolgroups:
if toolgroup.identifier not in self.toolgroups_to_tools:
try:
await self._index_tools(toolgroup)
except AuthenticationRequiredError:
# Send authentication errors back to the client so it knows
# that it needs to supply credentials for remote MCP servers.
raise
except Exception as e:
# Other errors that the client cannot fix are logged and
# those specific toolgroups are skipped.
logger.warning(f"Error listing tools for toolgroup {toolgroup.identifier}: {e}")
logger.debug(e, exc_info=True)
continue
all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier])
return ListToolDefsResponse(data=all_tools)
async def _index_tools(self, toolgroup: ToolGroup):
provider_impl = await super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)
tooldefs = tooldefs_response.data
for t in tooldefs:
t.toolgroup_id = toolgroup.identifier
self.toolgroups_to_tools[toolgroup.identifier] = tooldefs
for tool in tooldefs:
self.tool_to_toolgroup[tool.name] = toolgroup.identifier
async def list_tool_groups(self) -> ListToolGroupsResponse:
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id)
if tool_group is None:
raise ToolGroupNotFoundError(toolgroup_id)
return tool_group
async def get_tool(self, tool_name: str) -> ToolDef:
if tool_name in self.tool_to_toolgroup:
toolgroup_id = self.tool_to_toolgroup[tool_name]
tools = self.toolgroups_to_tools[toolgroup_id]
for tool in tools:
if tool.name == tool_name:
return tool
raise ValueError(f"Tool '{tool_name}' not found")
async def register_tool_group(
self,
toolgroup_id: str,
provider_id: str,
mcp_endpoint: URL | None = None,
args: dict[str, Any] | None = None,
) -> None:
toolgroup = ToolGroupWithOwner(
identifier=toolgroup_id,
provider_id=provider_id,
provider_resource_id=toolgroup_id,
mcp_endpoint=mcp_endpoint,
args=args,
)
await self.register_object(toolgroup)
# ideally, indexing of the tools should not be necessary because anyone using
# the tools should first list the tools and then use them. but there are assumptions
# baked in some of the code and tests right now.
if not toolgroup.mcp_endpoint:
await self._index_tools(toolgroup)
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
await self.unregister_object(await self.get_tool_group(toolgroup_id))
async def shutdown(self) -> None:
pass

View file

@ -0,0 +1,292 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
from llama_stack.apis.models import ModelType
from llama_stack.apis.resource import ResourceType
# Removed VectorStores import to avoid exposing public API
from llama_stack.apis.vector_io.vector_io import (
SearchRankingOptions,
VectorStoreChunkingStrategy,
VectorStoreDeleteResponse,
VectorStoreFileContentsResponse,
VectorStoreFileDeleteResponse,
VectorStoreFileObject,
VectorStoreFileStatus,
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.core.datatypes import (
VectorStoreWithOwner,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core::routing_tables")
class VectorStoresRoutingTable(CommonRoutingTableImpl):
"""Internal routing table for vector_store operations.
Does not inherit from VectorStores to avoid exposing public API endpoints.
Only provides internal routing functionality for VectorIORouter.
"""
# Internal methods only - no public API exposure
async def register_vector_store(
self,
vector_store_id: str,
embedding_model: str,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_store_id: str | None = None,
vector_store_name: str | None = None,
) -> Any:
if provider_id is None:
if len(self.impls_by_provider_id) > 0:
provider_id = list(self.impls_by_provider_id.keys())[0]
if len(self.impls_by_provider_id) > 1:
logger.warning(
f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
)
else:
raise ValueError("No provider available. Please configure a vector_io provider.")
model = await lookup_model(self, embedding_model)
if model is None:
raise ModelNotFoundError(embedding_model)
if model.model_type != ModelType.embedding:
raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
vector_store = VectorStoreWithOwner(
identifier=vector_store_id,
type=ResourceType.vector_store.value,
provider_id=provider_id,
provider_resource_id=provider_vector_store_id,
embedding_model=embedding_model,
embedding_dimension=embedding_dimension,
vector_store_name=vector_store_name,
)
await self.register_object(vector_store)
return vector_store
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
await self.assert_action_allowed("read", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store(vector_store_id)
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
await self.assert_action_allowed("update", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store(
vector_store_id=vector_store_id,
name=name,
expires_after=expires_after,
metadata=metadata,
)
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
await self.assert_action_allowed("delete", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
result = await provider.openai_delete_vector_store(vector_store_id)
await self.unregister_vector_store(vector_store_id)
return result
async def unregister_vector_store(self, vector_store_id: str) -> None:
"""Remove the vector store from the routing table registry."""
try:
vector_store_obj = await self.get_object_by_identifier("vector_store", vector_store_id)
if vector_store_obj:
await self.unregister_object(vector_store_obj)
except Exception as e:
# Log the error but don't fail the operation
logger.warning(f"Failed to unregister vector store {vector_store_id} from routing table: {e}")
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
search_mode: str | None = "vector",
) -> VectorStoreSearchResponsePage:
await self.assert_action_allowed("read", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_search_vector_store(
vector_store_id=vector_store_id,
query=query,
filters=filters,
max_num_results=max_num_results,
ranking_options=ranking_options,
rewrite_query=rewrite_query,
search_mode=search_mode,
)
async def openai_attach_file_to_vector_store(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject:
await self.assert_action_allowed("update", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_attach_file_to_vector_store(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
chunking_strategy=chunking_strategy,
)
async def openai_list_files_in_vector_store(
self,
vector_store_id: str,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
filter: VectorStoreFileStatus | None = None,
) -> list[VectorStoreFileObject]:
await self.assert_action_allowed("read", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_list_files_in_vector_store(
vector_store_id=vector_store_id,
limit=limit,
order=order,
after=after,
before=before,
filter=filter,
)
async def openai_retrieve_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
await self.assert_action_allowed("read", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)
async def openai_retrieve_vector_store_file_contents(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
await self.assert_action_allowed("read", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_contents(
vector_store_id=vector_store_id,
file_id=file_id,
)
async def openai_update_vector_store_file(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any],
) -> VectorStoreFileObject:
await self.assert_action_allowed("update", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
)
async def openai_delete_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileDeleteResponse:
await self.assert_action_allowed("delete", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_delete_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)
async def openai_create_vector_store_file_batch(
self,
vector_store_id: str,
file_ids: list[str],
attributes: dict[str, Any] | None = None,
chunking_strategy: Any | None = None,
):
await self.assert_action_allowed("update", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_create_vector_store_file_batch(
vector_store_id=vector_store_id,
file_ids=file_ids,
attributes=attributes,
chunking_strategy=chunking_strategy,
)
async def openai_retrieve_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
):
await self.assert_action_allowed("read", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_batch(
batch_id=batch_id,
vector_store_id=vector_store_id,
)
async def openai_list_files_in_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
after: str | None = None,
before: str | None = None,
filter: str | None = None,
limit: int | None = 20,
order: str | None = "desc",
):
await self.assert_action_allowed("read", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_list_files_in_vector_store_file_batch(
batch_id=batch_id,
vector_store_id=vector_store_id,
after=after,
before=before,
filter=filter,
limit=limit,
order=order,
)
async def openai_cancel_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
):
await self.assert_action_allowed("update", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_cancel_vector_store_file_batch(
batch_id=batch_id,
vector_store_id=vector_store_id,
)

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,187 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import httpx
from aiohttp import hdrs
from llama_stack.core.datatypes import AuthenticationConfig, User
from llama_stack.core.request_headers import user_from_scope
from llama_stack.core.server.auth_providers import create_auth_provider
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="core::auth")
class AuthenticationMiddleware:
"""Middleware that authenticates requests using configured authentication provider.
This middleware:
1. Extracts the Bearer token from the Authorization header
2. Uses the configured auth provider to validate the token
3. Extracts user attributes from the provider's response
4. Makes these attributes available to the route handlers for access control
Unauthenticated Access:
Endpoints can opt out of authentication by setting require_authentication=False
in their @webmethod decorator. This is typically used for operational endpoints
like /health and /version to support monitoring, load balancers, and observability tools.
The middleware supports multiple authentication providers through the AuthProvider interface:
- Kubernetes: Validates tokens against the Kubernetes API server
- Custom: Validates tokens against a custom endpoint
Authentication Request Format for Custom Auth Provider:
```json
{
"api_key": "the-api-key-extracted-from-auth-header",
"request": {
"path": "/models/list",
"headers": {
"content-type": "application/json",
"user-agent": "..."
// All headers except Authorization
},
"params": {
"limit": ["100"],
"offset": ["0"]
// Query parameters as key -> list of values
}
}
}
```
Expected Auth Endpoint Response Format:
```json
{
"access_attributes": { // Structured attribute format
"roles": ["admin", "user"],
"teams": ["ml-team", "nlp-team"],
"projects": ["llama-3", "project-x"],
"namespaces": ["research"]
},
"message": "Optional message about auth result"
}
```
Token Validation:
Each provider implements its own token validation logic:
- Kubernetes: Uses TokenReview API to validate service account tokens
- Custom: Sends token to custom endpoint for validation
Attribute-Based Access Control:
The attributes returned by the auth provider are used to determine which
resources the user can access. Resources can specify required attributes
using the access_attributes field. For a user to access a resource:
1. All attribute categories specified in the resource must be present in the user's attributes
2. For each category, the user must have at least one matching value
If the auth provider doesn't return any attributes, the user will only be able to
access resources that don't have access_attributes defined.
"""
def __init__(self, app, auth_config: AuthenticationConfig, impls):
self.app = app
self.impls = impls
self.auth_provider = create_auth_provider(auth_config)
async def __call__(self, scope, receive, send):
if scope["type"] == "http":
# Find the route and check if authentication is required
path = scope.get("path", "")
method = scope.get("method", hdrs.METH_GET)
if not hasattr(self, "route_impls"):
self.route_impls = initialize_route_impls(self.impls)
webmethod = None
try:
_, _, _, webmethod = find_matching_route(method, path, self.route_impls)
except ValueError:
# If no matching endpoint is found, pass here to run auth anyways
pass
# If webmethod explicitly sets require_authentication=False, allow without auth
if webmethod and webmethod.require_authentication is False:
logger.debug(f"Allowing unauthenticated access to endpoint: {path}")
return await self.app(scope, receive, send)
# Handle authentication
headers = dict(scope.get("headers", []))
auth_header = headers.get(b"authorization", b"").decode()
if not auth_header:
error_msg = self.auth_provider.get_auth_error_message(scope)
return await self._send_auth_error(send, error_msg)
if not auth_header.startswith("Bearer "):
return await self._send_auth_error(send, "Invalid Authorization header format")
token = auth_header.split("Bearer ", 1)[1]
# Validate token and get access attributes
try:
validation_result = await self.auth_provider.validate_token(token, scope)
except httpx.TimeoutException:
logger.exception("Authentication request timed out")
return await self._send_auth_error(send, "Authentication service timeout")
except ValueError as e:
logger.exception("Error during authentication")
return await self._send_auth_error(send, str(e))
except Exception:
logger.exception("Error during authentication")
return await self._send_auth_error(send, "Authentication service error")
# Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
# can identify the requester and enforce per-client rate limits.
scope["authenticated_client_id"] = token
# Store attributes in request scope
scope["principal"] = validation_result.principal
if validation_result.attributes:
scope["user_attributes"] = validation_result.attributes
logger.debug(
f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes"
)
# Scope-based API access control
if webmethod and webmethod.required_scope:
user = user_from_scope(scope)
if not _has_required_scope(webmethod.required_scope, user):
return await self._send_auth_error(
send,
f"Access denied: user does not have required scope: {webmethod.required_scope}",
status=403,
)
return await self.app(scope, receive, send)
async def _send_auth_error(self, send, message, status=401):
await send(
{
"type": "http.response.start",
"status": status,
"headers": [[b"content-type", b"application/json"]],
}
)
error_key = "message" if status == 401 else "detail"
error_msg = json.dumps({"error": {error_key: message}}).encode()
await send({"type": "http.response.body", "body": error_msg})
def _has_required_scope(required_scope: str, user: User | None) -> bool:
# if no user, assume auth is not enabled
if not user:
return True
if not user.attributes:
return False
user_scopes = user.attributes.get("scopes", [])
return required_scope in user_scopes

View file

@ -0,0 +1,494 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import ssl
from abc import ABC, abstractmethod
from urllib.parse import parse_qs, urljoin, urlparse
import httpx
import jwt
from pydantic import BaseModel, Field
from llama_stack.apis.common.errors import TokenValidationError
from llama_stack.core.datatypes import (
AuthenticationConfig,
CustomAuthConfig,
GitHubTokenAuthConfig,
KubernetesAuthProviderConfig,
OAuth2TokenAuthConfig,
User,
)
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="core::auth")
class AuthResponse(BaseModel):
"""The format of the authentication response from the auth endpoint."""
principal: str
# further attributes that may be used for access control decisions
attributes: dict[str, list[str]] | None = None
message: str | None = Field(
default=None, description="Optional message providing additional context about the authentication result."
)
class AuthRequestContext(BaseModel):
path: str = Field(description="The path of the request being authenticated")
headers: dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
params: dict[str, list[str]] = Field(default_factory=dict, description="Query parameters from the original request")
class AuthRequest(BaseModel):
api_key: str = Field(description="The API key extracted from the Authorization header")
request: AuthRequestContext = Field(description="Context information about the request being authenticated")
class AuthProvider(ABC):
"""Abstract base class for authentication providers."""
@abstractmethod
async def validate_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a token and return access attributes."""
pass
@abstractmethod
async def close(self):
"""Clean up any resources."""
pass
def get_auth_error_message(self, scope: dict | None = None) -> str:
"""Return provider-specific authentication error message."""
return "Authentication required"
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]:
attributes: dict[str, list[str]] = {}
for claim_key, attribute_key in mapping.items():
# First try dot notation for nested traversal (e.g., "resource_access.llamastack.roles")
# Then fall back to literal key with dots (e.g., "my.dotted.key")
claim: object = claims
keys = claim_key.split(".")
for key in keys:
if isinstance(claim, dict) and key in claim:
claim = claim[key]
else:
claim = None
break
if claim is None and claim_key in claims:
# Fall back to checking if claim_key exists as a literal key
claim = claims[claim_key]
if claim is None:
continue
if isinstance(claim, list):
values = claim
elif isinstance(claim, str):
values = claim.split()
else:
continue
if attribute_key in attributes:
attributes[attribute_key].extend(values)
else:
attributes[attribute_key] = values
return attributes
class OAuth2TokenAuthProvider(AuthProvider):
"""
JWT token authentication provider that validates a JWT token and extracts access attributes.
This should be the standard authentication provider for most use cases.
"""
def __init__(self, config: OAuth2TokenAuthConfig):
self.config = config
self._jwks_client: jwt.PyJWKClient | None = None
async def validate_token(self, token: str, scope: dict | None = None) -> User:
if self.config.jwks:
return await self.validate_jwt_token(token, scope)
if self.config.introspection:
return await self.introspect_token(token, scope)
raise ValueError("One of jwks or introspection must be configured")
def _get_jwks_client(self) -> jwt.PyJWKClient:
if self._jwks_client is None:
ssl_context = None
if not self.config.verify_tls:
# Disable SSL verification if verify_tls is False
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
elif self.config.tls_cafile:
# Use custom CA file if provided
ssl_context = ssl.create_default_context(
cafile=self.config.tls_cafile.as_posix(),
)
# If verify_tls is True and no tls_cafile, ssl_context remains None (use system defaults)
# Prepare headers for JWKS request - this is needed for Kubernetes to authenticate
# to the JWK endpoint, we must use the token in the config to authenticate
headers = {}
if self.config.jwks and self.config.jwks.token:
headers["Authorization"] = f"Bearer {self.config.jwks.token}"
self._jwks_client = jwt.PyJWKClient(
self.config.jwks.uri if self.config.jwks else None,
cache_keys=True,
max_cached_keys=10,
lifespan=self.config.jwks.key_recheck_period if self.config.jwks else None,
headers=headers,
ssl_context=ssl_context,
)
return self._jwks_client
async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a token using the JWT token."""
try:
jwks_client: jwt.PyJWKClient = self._get_jwks_client()
signing_key = jwks_client.get_signing_key_from_jwt(token)
algorithm = jwt.get_unverified_header(token)["alg"]
claims = jwt.decode(
token,
signing_key.key,
algorithms=[algorithm],
audience=self.config.audience,
issuer=self.config.issuer,
options={"verify_exp": True, "verify_aud": True, "verify_iss": True},
)
# Decode and verify the JWT
claims = jwt.decode(
token,
signing_key.key,
algorithms=[algorithm],
audience=self.config.audience,
issuer=self.config.issuer,
options={"verify_exp": True, "verify_aud": True, "verify_iss": True},
)
except Exception as exc:
raise ValueError("Invalid JWT token") from exc
# There are other standard claims, the most relevant of which is `scope`.
# We should incorporate these into the access attributes.
principal = claims["sub"]
access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
return User(
principal=principal,
attributes=access_attributes,
)
async def introspect_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a token using token introspection as defined by RFC 7662."""
form = {
"token": token,
}
if self.config.introspection is None:
raise ValueError("Introspection is not configured")
if self.config.introspection.send_secret_in_body:
form["client_id"] = self.config.introspection.client_id
form["client_secret"] = self.config.introspection.client_secret
auth = None
else:
auth = (self.config.introspection.client_id, self.config.introspection.client_secret)
ssl_ctxt = None
if self.config.tls_cafile:
ssl_ctxt = ssl.create_default_context(cafile=self.config.tls_cafile.as_posix())
try:
async with httpx.AsyncClient(verify=ssl_ctxt) as client:
response = await client.post(
self.config.introspection.url,
data=form,
auth=auth,
timeout=10.0, # Add a reasonable timeout
)
if response.status_code != httpx.codes.OK:
logger.warning(f"Token introspection failed with status code: {response.status_code}")
raise ValueError(f"Token introspection failed: {response.status_code}")
fields = response.json()
if not fields["active"]:
raise ValueError("Token not active")
principal = fields["sub"] or fields["username"]
access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping)
return User(
principal=principal,
attributes=access_attributes,
)
except httpx.TimeoutException:
logger.exception("Token introspection request timed out")
raise
except ValueError:
# Re-raise ValueError exceptions to preserve their message
raise
except Exception as e:
logger.exception("Error during token introspection")
raise ValueError("Token introspection error") from e
async def close(self):
pass
def get_auth_error_message(self, scope: dict | None = None) -> str:
"""Return OAuth2-specific authentication error message."""
if self.config.issuer:
return f"Authentication required. Please provide a valid OAuth2 Bearer token from {self.config.issuer}"
elif self.config.introspection:
# Extract domain from introspection URL for a cleaner message
domain = urlparse(self.config.introspection.url).netloc
return f"Authentication required. Please provide a valid OAuth2 Bearer token validated by {domain}"
else:
return "Authentication required. Please provide a valid OAuth2 Bearer token in the Authorization header"
class CustomAuthProvider(AuthProvider):
"""Custom authentication provider that uses an external endpoint."""
def __init__(self, config: CustomAuthConfig):
self.config = config
self._client = None
async def validate_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a token using the custom authentication endpoint."""
if scope is None:
scope = {}
headers = dict(scope.get("headers", []))
path = scope.get("path", "")
request_headers = {k.decode(): v.decode() for k, v in headers.items()}
# Remove sensitive headers
if "authorization" in request_headers:
del request_headers["authorization"]
query_string = scope.get("query_string", b"").decode()
params = parse_qs(query_string)
# Build the auth request model
auth_request = AuthRequest(
api_key=token,
request=AuthRequestContext(
path=path,
headers=request_headers,
params=params,
),
)
# Validate with authentication endpoint
try:
async with httpx.AsyncClient() as client:
response = await client.post(
self.config.endpoint,
json=auth_request.model_dump(),
timeout=10.0, # Add a reasonable timeout
)
if response.status_code != httpx.codes.OK:
logger.warning(f"Authentication failed with status code: {response.status_code}")
raise ValueError(f"Authentication failed: {response.status_code}")
# Parse and validate the auth response
try:
response_data = response.json()
auth_response = AuthResponse(**response_data)
return User(principal=auth_response.principal, attributes=auth_response.attributes)
except Exception as e:
logger.exception("Error parsing authentication response")
raise ValueError("Invalid authentication response format") from e
except httpx.TimeoutException:
logger.exception("Authentication request timed out")
raise
except ValueError:
# Re-raise ValueError exceptions to preserve their message
raise
except Exception as e:
logger.exception("Error during authentication")
raise ValueError("Authentication service error") from e
async def close(self):
"""Close the HTTP client."""
if self._client:
await self._client.aclose()
self._client = None
def get_auth_error_message(self, scope: dict | None = None) -> str:
"""Return custom auth provider-specific authentication error message."""
domain = urlparse(self.config.endpoint).netloc
if domain:
return f"Authentication required. Please provide your API key as a Bearer token (validated by {domain})"
else:
return "Authentication required. Please provide your API key as a Bearer token in the Authorization header"
class GitHubTokenAuthProvider(AuthProvider):
"""
GitHub token authentication provider that validates GitHub access tokens directly.
This provider accepts GitHub personal access tokens or OAuth tokens and verifies
them against the GitHub API to get user information.
"""
def __init__(self, config: GitHubTokenAuthConfig):
self.config = config
async def validate_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a GitHub token by calling the GitHub API.
This validates tokens issued by GitHub (personal access tokens or OAuth tokens).
"""
try:
user_info = await _get_github_user_info(token, self.config.github_api_base_url)
except httpx.HTTPStatusError as e:
logger.warning(f"GitHub token validation failed: {e}")
raise ValueError("GitHub token validation failed. Please check your token and try again.") from e
principal = user_info["user"]["login"]
github_data = {
"login": user_info["user"]["login"],
"id": str(user_info["user"]["id"]),
"organizations": user_info.get("organizations", []),
}
access_attributes = get_attributes_from_claims(github_data, self.config.claims_mapping)
return User(
principal=principal,
attributes=access_attributes,
)
async def close(self):
"""Clean up any resources."""
pass
def get_auth_error_message(self, scope: dict | None = None) -> str:
"""Return GitHub-specific authentication error message."""
return "Authentication required. Please provide a valid GitHub access token (https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens) in the Authorization header (Bearer <token>)"
async def _get_github_user_info(access_token: str, github_api_base_url: str) -> dict:
"""Fetch user info and organizations from GitHub API."""
headers = {
"Authorization": f"Bearer {access_token}",
"Accept": "application/vnd.github.v3+json",
"User-Agent": "llama-stack",
}
async with httpx.AsyncClient() as client:
user_response = await client.get(f"{github_api_base_url}/user", headers=headers, timeout=10.0)
user_response.raise_for_status()
user_data = user_response.json()
return {
"user": user_data,
}
class KubernetesAuthProvider(AuthProvider):
"""
Kubernetes authentication provider that validates tokens using the Kubernetes SelfSubjectReview API.
This provider integrates with Kubernetes API server by using the
/apis/authentication.k8s.io/v1/selfsubjectreviews endpoint to validate tokens and extract user information.
"""
def __init__(self, config: KubernetesAuthProviderConfig):
self.config = config
def _httpx_verify_value(self) -> bool | str:
"""
Build the value for httpx's `verify` parameter.
- False disables verification.
- Path string points to a CA bundle.
- True uses system defaults.
"""
if not self.config.verify_tls:
return False
if self.config.tls_cafile:
return self.config.tls_cafile.as_posix()
return True
async def validate_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a token using Kubernetes SelfSubjectReview API endpoint."""
# Build the Kubernetes SelfSubjectReview API endpoint URL
review_api_url = urljoin(self.config.api_server_url, "/apis/authentication.k8s.io/v1/selfsubjectreviews")
# Create SelfSubjectReview request body
review_request = {"apiVersion": "authentication.k8s.io/v1", "kind": "SelfSubjectReview"}
verify = self._httpx_verify_value()
try:
async with httpx.AsyncClient(verify=verify, timeout=10.0) as client:
response = await client.post(
review_api_url,
json=review_request,
headers={
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
},
)
if response.status_code == httpx.codes.UNAUTHORIZED:
raise TokenValidationError("Invalid token")
if response.status_code != httpx.codes.CREATED:
logger.warning(f"Kubernetes SelfSubjectReview API failed with status code: {response.status_code}")
raise TokenValidationError(f"Token validation failed: {response.status_code}")
review_response = response.json()
# Extract user information from SelfSubjectReview response
status = review_response.get("status", {})
if not status:
raise ValueError("No status found in SelfSubjectReview response")
user_info = status.get("userInfo", {})
if not user_info:
raise ValueError("No userInfo found in SelfSubjectReview response")
username = user_info.get("username")
if not username:
raise ValueError("No username found in SelfSubjectReview response")
# Build user attributes from Kubernetes user info
user_attributes = get_attributes_from_claims(user_info, self.config.claims_mapping)
return User(
principal=username,
attributes=user_attributes,
)
except httpx.TimeoutException:
logger.warning("Kubernetes SelfSubjectReview API request timed out")
raise ValueError("Token validation timeout") from None
except Exception as e:
logger.warning(f"Error during token validation: {str(e)}")
raise ValueError(f"Token validation error: {str(e)}") from e
async def close(self):
"""Close any resources."""
pass
def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
"""Factory function to create the appropriate auth provider."""
provider_config = config.provider_config
if isinstance(provider_config, CustomAuthConfig):
return CustomAuthProvider(provider_config)
elif isinstance(provider_config, OAuth2TokenAuthConfig):
return OAuth2TokenAuthProvider(provider_config)
elif isinstance(provider_config, GitHubTokenAuthConfig):
return GitHubTokenAuthProvider(provider_config)
elif isinstance(provider_config, KubernetesAuthProviderConfig):
return KubernetesAuthProvider(provider_config)
else:
raise ValueError(f"Unknown authentication provider config type: {type(provider_config)}")

View file

@ -0,0 +1,110 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import time
from datetime import UTC, datetime, timedelta
from starlette.types import ASGIApp, Receive, Scope, Send
from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendType
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.kvstore.kvstore import _KVSTORE_BACKENDS, kvstore_impl
logger = get_logger(name=__name__, category="core::server")
class QuotaMiddleware:
"""
ASGI middleware that enforces separate quotas for authenticated and anonymous clients
within a configurable time window.
- For authenticated requests, it reads the client ID from the
`Authorization: Bearer <client_id>` header.
- For anonymous requests, it falls back to the IP address of the client.
Requests are counted in a KV store (e.g., SQLite), and HTTP 429 is returned
once a client exceeds its quota.
"""
def __init__(
self,
app: ASGIApp,
kv_config: KVStoreReference,
anonymous_max_requests: int,
authenticated_max_requests: int,
window_seconds: int = 86400,
):
self.app = app
self.kv_config = kv_config
self.kv: KVStore | None = None
self.anonymous_max_requests = anonymous_max_requests
self.authenticated_max_requests = authenticated_max_requests
self.window_seconds = window_seconds
async def _get_kv(self) -> KVStore:
if self.kv is None:
self.kv = await kvstore_impl(self.kv_config)
backend_config = _KVSTORE_BACKENDS.get(self.kv_config.backend)
if backend_config and backend_config.type == StorageBackendType.KV_SQLITE:
logger.warning(
"QuotaMiddleware: Using SQLite backend. Expiry/TTL is not enforced; cleanup is manual. "
f"window_seconds={self.window_seconds}"
)
return self.kv
async def __call__(self, scope: Scope, receive: Receive, send: Send):
if scope["type"] == "http":
# pick key & limit based on auth
auth_id = scope.get("authenticated_client_id")
if auth_id:
key_id = auth_id
limit = self.authenticated_max_requests
else:
# fallback to IP
client = scope.get("client")
key_id = client[0] if client else "anonymous"
limit = self.anonymous_max_requests
current_window = int(time.time() // self.window_seconds)
key = f"quota:{key_id}:{current_window}"
try:
kv = await self._get_kv()
prev = await kv.get(key) or "0"
count = int(prev) + 1
if int(prev) == 0:
# Set with expiration datetime when it is the first request in the window.
expiration = datetime.now(UTC) + timedelta(seconds=self.window_seconds)
await kv.set(key, str(count), expiration=expiration)
else:
await kv.set(key, str(count))
except Exception:
logger.exception("Failed to access KV store for quota")
return await self._send_error(send, 500, "Quota service error")
if count > limit:
logger.warning(
"Quota exceeded for client %s: %d/%d",
key_id,
count,
limit,
)
return await self._send_error(send, 429, "Quota exceeded")
return await self.app(scope, receive, send)
async def _send_error(self, send: Send, status: int, message: str):
await send(
{
"type": "http.response.start",
"status": status,
"headers": [[b"content-type", b"application/json"]],
}
)
body = json.dumps({"error": {"message": message}}).encode()
await send({"type": "http.response.body", "body": body})

View file

@ -0,0 +1,141 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import inspect
import re
from collections.abc import Callable
from typing import Any
from aiohttp import hdrs
from starlette.routing import Route
from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
from llama_stack.core.resolver import api_protocol_map
from llama_stack.schema_utils import WebMethod
EndpointFunc = Callable[..., Any]
PathParams = dict[str, str]
RouteInfo = tuple[EndpointFunc, str, WebMethod]
PathImpl = dict[str, RouteInfo]
RouteImpls = dict[str, PathImpl]
RouteMatch = tuple[EndpointFunc, PathParams, str, WebMethod]
def toolgroup_protocol_map():
return {
SpecialToolGroup.rag_tool: RAGToolRuntime,
}
def get_all_api_routes(
external_apis: dict[Api, ExternalApiSpec] | None = None,
) -> dict[Api, list[tuple[Route, WebMethod]]]:
apis = {}
protocols = api_protocol_map(external_apis)
toolgroup_protocols = toolgroup_protocol_map()
for api, protocol in protocols.items():
routes = []
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
# HACK ALERT
if api == Api.tool_runtime:
for tool_group in SpecialToolGroup:
sub_protocol = toolgroup_protocols[tool_group]
sub_protocol_methods = inspect.getmembers(sub_protocol, predicate=inspect.isfunction)
for name, method in sub_protocol_methods:
if not hasattr(method, "__webmethod__"):
continue
protocol_methods.append((f"{tool_group.value}.{name}", method))
for name, method in protocol_methods:
# Get all webmethods for this method (supports multiple decorators)
webmethods = getattr(method, "__webmethods__", [])
if not webmethods:
continue
# Create routes for each webmethod decorator
for webmethod in webmethods:
path = f"/{webmethod.level}/{webmethod.route.lstrip('/')}"
if webmethod.method == hdrs.METH_GET:
http_method = hdrs.METH_GET
elif webmethod.method == hdrs.METH_DELETE:
http_method = hdrs.METH_DELETE
else:
http_method = hdrs.METH_POST
routes.append(
(Route(path=path, methods=[http_method], name=name, endpoint=None), webmethod)
) # setting endpoint to None since don't use a Router object
apis[api] = routes
return apis
def initialize_route_impls(impls, external_apis: dict[Api, ExternalApiSpec] | None = None) -> RouteImpls:
api_to_routes = get_all_api_routes(external_apis)
route_impls: RouteImpls = {}
def _convert_path_to_regex(path: str) -> str:
# Convert {param} to named capture groups
# handle {param:path} as well which allows for forward slashes in the param value
pattern = re.sub(
r"{(\w+)(?::path)?}",
lambda m: f"(?P<{m.group(1)}>{'[^/]+' if not m.group(0).endswith(':path') else '.+'})",
path,
)
return f"^{pattern}$"
for api, api_routes in api_to_routes.items():
if api not in impls:
continue
for route, webmethod in api_routes:
impl = impls[api]
func = getattr(impl, route.name)
# Get the first (and typically only) method from the set, filtering out HEAD
available_methods = [m for m in route.methods if m != "HEAD"]
if not available_methods:
continue # Skip if only HEAD method is available
method = available_methods[0].lower()
if method not in route_impls:
route_impls[method] = {}
route_impls[method][_convert_path_to_regex(route.path)] = (
func,
route.path,
webmethod,
)
return route_impls
def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> RouteMatch:
"""Find the matching endpoint implementation for a given method and path.
Args:
method: HTTP method (GET, POST, etc.)
path: URL path to match against
route_impls: A dictionary of endpoint implementations
Returns:
A tuple of (endpoint_function, path_params, route_path, webmethod_metadata)
Raises:
ValueError: If no matching endpoint is found
"""
impls = route_impls.get(method.lower())
if not impls:
raise ValueError(f"No endpoint found for {path}")
for regex, (func, route_path, webmethod) in impls.items():
match = re.match(regex, path)
if match:
# Extract named groups from the regex match
path_params = match.groupdict()
return func, path_params, route_path, webmethod
raise ValueError(f"No endpoint found for {path}")

View file

@ -0,0 +1,536 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import concurrent.futures
import functools
import inspect
import json
import logging # allow-direct-logging
import os
import sys
import traceback
import warnings
from collections.abc import Callable
from contextlib import asynccontextmanager
from importlib.metadata import version as parse_version
from pathlib import Path
from typing import Annotated, Any, get_origin
import httpx
import rich.pretty
import yaml
from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from openai import BadRequestError
from pydantic import BaseModel, ValidationError
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.core.access_control.access_control import AccessDeniedError
from llama_stack.core.datatypes import (
AuthenticationRequiredError,
StackRunConfig,
process_cors_config,
)
from llama_stack.core.distribution import builtin_automatically_routed_apis
from llama_stack.core.external import load_external_apis
from llama_stack.core.request_headers import (
PROVIDER_DATA_VAR,
request_provider_data_context,
user_from_scope,
)
from llama_stack.core.server.routes import get_all_api_routes
from llama_stack.core.stack import (
Stack,
cast_image_name_to_string,
replace_env_vars,
)
from llama_stack.core.telemetry import Telemetry
from llama_stack.core.telemetry.tracing import CURRENT_TRACE_CONTEXT, setup_logger
from llama_stack.core.utils.config import redact_sensitive_fields
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
from llama_stack.core.utils.context import preserve_contexts_async_generator
from llama_stack.log import LoggingConfig, get_logger, setup_logging
from llama_stack.providers.datatypes import Api
from .auth import AuthenticationMiddleware
from .quota import QuotaMiddleware
from .tracing import TracingMiddleware
REPO_ROOT = Path(__file__).parent.parent.parent.parent
logger = get_logger(name=__name__, category="core::server")
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
log = file if hasattr(file, "write") else sys.stderr
traceback.print_stack(file=log)
log.write(warnings.formatwarning(message, category, filename, lineno, line))
if os.environ.get("LLAMA_STACK_TRACE_WARNINGS"):
warnings.showwarning = warn_with_traceback
def create_sse_event(data: Any) -> str:
if isinstance(data, BaseModel):
data = data.model_dump_json()
else:
data = json.dumps(data)
return f"data: {data}\n\n"
async def global_exception_handler(request: Request, exc: Exception):
traceback.print_exception(exc)
http_exc = translate_exception(exc)
return JSONResponse(status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}})
def translate_exception(exc: Exception) -> HTTPException | RequestValidationError:
if isinstance(exc, ValidationError):
exc = RequestValidationError(exc.errors())
if isinstance(exc, RequestValidationError):
return HTTPException(
status_code=httpx.codes.BAD_REQUEST,
detail={
"errors": [
{
"loc": list(error["loc"]),
"msg": error["msg"],
"type": error["type"],
}
for error in exc.errors()
]
},
)
elif isinstance(exc, ConflictError):
return HTTPException(status_code=httpx.codes.CONFLICT, detail=str(exc))
elif isinstance(exc, ResourceNotFoundError):
return HTTPException(status_code=httpx.codes.NOT_FOUND, detail=str(exc))
elif isinstance(exc, ValueError):
return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}")
elif isinstance(exc, BadRequestError):
return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=str(exc))
elif isinstance(exc, PermissionError | AccessDeniedError):
return HTTPException(status_code=httpx.codes.FORBIDDEN, detail=f"Permission denied: {str(exc)}")
elif isinstance(exc, ConnectionError | httpx.ConnectError):
return HTTPException(status_code=httpx.codes.BAD_GATEWAY, detail=str(exc))
elif isinstance(exc, asyncio.TimeoutError | TimeoutError):
return HTTPException(status_code=httpx.codes.GATEWAY_TIMEOUT, detail=f"Operation timed out: {str(exc)}")
elif isinstance(exc, NotImplementedError):
return HTTPException(status_code=httpx.codes.NOT_IMPLEMENTED, detail=f"Not implemented: {str(exc)}")
elif isinstance(exc, AuthenticationRequiredError):
return HTTPException(status_code=httpx.codes.UNAUTHORIZED, detail=f"Authentication required: {str(exc)}")
elif hasattr(exc, "status_code") and isinstance(getattr(exc, "status_code", None), int):
# Handle provider SDK exceptions (e.g., OpenAI's APIStatusError and subclasses)
# These include AuthenticationError (401), PermissionDeniedError (403), etc.
# This preserves the actual HTTP status code from the provider
status_code = exc.status_code
detail = str(exc)
return HTTPException(status_code=status_code, detail=detail)
else:
return HTTPException(
status_code=httpx.codes.INTERNAL_SERVER_ERROR,
detail="Internal server error: An unexpected error occurred.",
)
class StackApp(FastAPI):
"""
A wrapper around the FastAPI application to hold a reference to the Stack instance so that we can
start background tasks (e.g. refresh model registry periodically) from the lifespan context manager.
"""
def __init__(self, config: StackRunConfig, *args, **kwargs):
super().__init__(*args, **kwargs)
self.stack: Stack = Stack(config)
# This code is called from a running event loop managed by uvicorn so we cannot simply call
# asyncio.run() to initialize the stack. We cannot await either since this is not an async
# function.
# As a workaround, we use a thread pool executor to run the initialize() method
# in a separate thread.
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(asyncio.run, self.stack.initialize())
future.result()
@asynccontextmanager
async def lifespan(app: StackApp):
server_version = parse_version("llama-stack")
logger.info(f"Starting up Llama Stack server (version: {server_version})")
assert app.stack is not None
app.stack.create_registry_refresh_task()
yield
logger.info("Shutting down")
await app.stack.shutdown()
def is_streaming_request(func_name: str, request: Request, **kwargs):
# TODO: pass the api method and punt it to the Protocol definition directly
# If there's a stream parameter at top level, use it
if "stream" in kwargs:
return kwargs["stream"]
# If there's a stream parameter inside a "params" parameter, e.g. openai_chat_completion() use it
if "params" in kwargs:
params = kwargs["params"]
if hasattr(params, "stream"):
return params.stream
return False
async def maybe_await(value):
if inspect.iscoroutine(value):
return await value
return value
async def sse_generator(event_gen_coroutine):
event_gen = None
try:
event_gen = await event_gen_coroutine
async for item in event_gen:
yield create_sse_event(item)
except asyncio.CancelledError:
logger.info("Generator cancelled")
if event_gen:
await event_gen.aclose()
except Exception as e:
logger.exception("Error in sse_generator")
yield create_sse_event(
{
"error": {
"message": str(translate_exception(e)),
},
}
)
async def log_request_pre_validation(request: Request):
if request.method in ("POST", "PUT", "PATCH"):
try:
body_bytes = await request.body()
if body_bytes:
try:
parsed_body = json.loads(body_bytes.decode())
log_output = rich.pretty.pretty_repr(parsed_body)
except (json.JSONDecodeError, UnicodeDecodeError):
log_output = repr(body_bytes)
logger.debug(f"Incoming raw request body for {request.method} {request.url.path}:\n{log_output}")
else:
logger.debug(f"Incoming {request.method} {request.url.path} request with empty body.")
except Exception as e:
logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")
def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
@functools.wraps(func)
async def route_handler(request: Request, **kwargs):
# Get auth attributes from the request scope
user = user_from_scope(request.scope)
await log_request_pre_validation(request)
test_context_token = None
test_context_var = None
reset_test_context_fn = None
# Use context manager with both provider data and auth attributes
with request_provider_data_context(request.headers, user):
if os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE"):
from llama_stack.core.testing_context import (
TEST_CONTEXT,
reset_test_context,
sync_test_context_from_provider_data,
)
test_context_token = sync_test_context_from_provider_data()
test_context_var = TEST_CONTEXT
reset_test_context_fn = reset_test_context
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
try:
if is_streaming:
context_vars = [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
if test_context_var is not None:
context_vars.append(test_context_var)
gen = preserve_contexts_async_generator(sse_generator(func(**kwargs)), context_vars)
return StreamingResponse(gen, media_type="text/event-stream")
else:
value = func(**kwargs)
result = await maybe_await(value)
if isinstance(result, PaginatedResponse) and result.url is None:
result.url = route
if method.upper() == "DELETE" and result is None:
return Response(status_code=httpx.codes.NO_CONTENT)
return result
except Exception as e:
if logger.isEnabledFor(logging.INFO):
logger.exception(f"Error executing endpoint {route=} {method=}")
else:
logger.error(f"Error executing endpoint {route=} {method=}: {str(e)}")
raise translate_exception(e) from e
finally:
if test_context_token is not None and reset_test_context_fn is not None:
reset_test_context_fn(test_context_token)
sig = inspect.signature(func)
new_params = [inspect.Parameter("request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request)]
new_params.extend(sig.parameters.values())
path_params = extract_path_params(route)
if method == "post":
# Annotate parameters that are in the path with Path(...) and others with Body(...),
# but preserve existing File() and Form() annotations for multipart form data
new_params = (
[new_params[0]]
+ [
(
param.replace(annotation=Annotated[param.annotation, FastapiPath(..., title=param.name)])
if param.name in path_params
else (
param # Keep original annotation if it's already an Annotated type
if get_origin(param.annotation) is Annotated
else param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])
)
)
for param in new_params[1:]
]
)
route_handler.__signature__ = sig.replace(parameters=new_params)
return route_handler
class ClientVersionMiddleware:
def __init__(self, app):
self.app = app
self.server_version = parse_version("llama-stack")
async def __call__(self, scope, receive, send):
if scope["type"] == "http":
headers = dict(scope.get("headers", []))
client_version = headers.get(b"x-llamastack-client-version", b"").decode()
if client_version:
try:
client_version_parts = tuple(map(int, client_version.split(".")[:2]))
server_version_parts = tuple(map(int, self.server_version.split(".")[:2]))
if client_version_parts != server_version_parts:
async def send_version_error(send):
await send(
{
"type": "http.response.start",
"status": httpx.codes.UPGRADE_REQUIRED,
"headers": [[b"content-type", b"application/json"]],
}
)
error_msg = json.dumps(
{
"error": {
"message": f"Client version {client_version} is not compatible with server version {self.server_version}. Please update your client."
}
}
).encode()
await send({"type": "http.response.body", "body": error_msg})
return await send_version_error(send)
except (ValueError, IndexError):
# If version parsing fails, let the request through
pass
return await self.app(scope, receive, send)
def create_app() -> StackApp:
"""Create and configure the FastAPI application.
This factory function reads configuration from environment variables:
- LLAMA_STACK_CONFIG: Path to config file (required)
Returns:
Configured StackApp instance.
"""
# Initialize logging from environment variables first
setup_logging()
config_file = os.getenv("LLAMA_STACK_CONFIG")
if config_file is None:
raise ValueError("LLAMA_STACK_CONFIG environment variable is required")
config_file = resolve_config_or_distro(config_file, Mode.RUN)
# Load and process configuration
logger_config = None
with open(config_file) as fp:
config_contents = yaml.safe_load(fp)
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
logger_config = LoggingConfig(**cfg)
logger = get_logger(name=__name__, category="core::server", config=logger_config)
config = replace_env_vars(config_contents)
config = StackRunConfig(**cast_image_name_to_string(config))
_log_run_config(run_config=config)
app = StackApp(
lifespan=lifespan,
docs_url="/docs",
redoc_url="/redoc",
openapi_url="/openapi.json",
config=config,
)
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
app.add_middleware(ClientVersionMiddleware)
impls = app.stack.impls
if config.server.auth:
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth, impls=impls)
else:
if config.server.quota:
quota = config.server.quota
logger.warning(
"Configured authenticated_max_requests (%d) but no auth is enabled; "
"falling back to anonymous_max_requests (%d) for all the requests",
quota.authenticated_max_requests,
quota.anonymous_max_requests,
)
if config.server.quota:
logger.info("Enabling quota middleware for authenticated and anonymous clients")
quota = config.server.quota
anonymous_max_requests = quota.anonymous_max_requests
# if auth is disabled, use the anonymous max requests
authenticated_max_requests = quota.authenticated_max_requests if config.server.auth else anonymous_max_requests
kv_config = quota.kvstore
window_map = {"day": 86400}
window_seconds = window_map[quota.period.value]
app.add_middleware(
QuotaMiddleware,
kv_config=kv_config,
anonymous_max_requests=anonymous_max_requests,
authenticated_max_requests=authenticated_max_requests,
window_seconds=window_seconds,
)
if config.server.cors:
logger.info("Enabling CORS")
cors_config = process_cors_config(config.server.cors)
if cors_config:
app.add_middleware(CORSMiddleware, **cors_config.model_dump())
if config.telemetry.enabled:
setup_logger(Telemetry())
# Load external APIs if configured
external_apis = load_external_apis(config)
all_routes = get_all_api_routes(external_apis)
if config.apis:
apis_to_serve = set(config.apis)
else:
apis_to_serve = set(impls.keys())
for inf in builtin_automatically_routed_apis():
# if we do not serve the corresponding router API, we should not serve the routing table API
if inf.router_api.value not in apis_to_serve:
continue
apis_to_serve.add(inf.routing_table_api.value)
apis_to_serve.add("inspect")
apis_to_serve.add("providers")
apis_to_serve.add("prompts")
apis_to_serve.add("conversations")
for api_str in apis_to_serve:
api = Api(api_str)
routes = all_routes[api]
try:
impl = impls[api]
except KeyError as e:
raise ValueError(f"Could not find provider implementation for {api} API") from e
for route, _ in routes:
if not hasattr(impl, route.name):
# ideally this should be a typing violation already
raise ValueError(f"Could not find method {route.name} on {impl}!")
impl_method = getattr(impl, route.name)
# Filter out HEAD method since it's automatically handled by FastAPI for GET routes
available_methods = [m for m in route.methods if m != "HEAD"]
if not available_methods:
raise ValueError(f"No methods found for {route.name} on {impl}")
method = available_methods[0]
logger.debug(f"{method} {route.path}")
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
getattr(app, method.lower())(route.path, response_model=None)(
create_dynamic_typed_route(
impl_method,
method.lower(),
route.path,
)
)
logger.debug(f"serving APIs: {apis_to_serve}")
app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler)
if config.telemetry.enabled:
app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
return app
def _log_run_config(run_config: StackRunConfig):
"""Logs the run config with redacted fields and disabled providers removed."""
logger.info("Run configuration:")
safe_config = redact_sensitive_fields(run_config.model_dump(mode="json"))
clean_config = remove_disabled_providers(safe_config)
logger.info(yaml.dump(clean_config, indent=2))
def extract_path_params(route: str) -> list[str]:
segments = route.split("/")
params = [seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")]
# to handle path params like {param:path}
params = [param.split(":")[0] for param in params]
return params
def remove_disabled_providers(obj):
if isinstance(obj, dict):
keys = ["provider_id", "shield_id", "provider_model_id", "model_id"]
if any(k in obj and obj[k] in ("__disabled__", "", None) for k in keys):
return None
return {k: v for k, v in ((k, remove_disabled_providers(v)) for k, v in obj.items()) if v is not None}
elif isinstance(obj, list):
return [item for item in (remove_disabled_providers(i) for i in obj) if item is not None]
else:
return obj

View file

@ -0,0 +1,80 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from aiohttp import hdrs
from llama_stack.core.external import ExternalApiSpec
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
from llama_stack.core.telemetry.tracing import end_trace, start_trace
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="core::server")
class TracingMiddleware:
def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]):
self.app = app
self.impls = impls
self.external_apis = external_apis
# FastAPI built-in paths that should bypass custom routing
self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
async def __call__(self, scope, receive, send):
if scope.get("type") == "lifespan":
return await self.app(scope, receive, send)
path = scope.get("path", "")
# Check if the path is a FastAPI built-in path
if path.startswith(self.fastapi_paths):
# Pass through to FastAPI's built-in handlers
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
return await self.app(scope, receive, send)
if not hasattr(self, "route_impls"):
self.route_impls = initialize_route_impls(self.impls, self.external_apis)
try:
_, _, route_path, webmethod = find_matching_route(
scope.get("method", hdrs.METH_GET), path, self.route_impls
)
except ValueError:
# If no matching endpoint is found, pass through to FastAPI
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
return await self.app(scope, receive, send)
# Log deprecation warning if route is deprecated
if getattr(webmethod, "deprecated", False):
logger.warning(
f"DEPRECATED ROUTE USED: {scope.get('method', 'GET')} {path} - "
f"This route is deprecated and may be removed in a future version. "
f"Please check the docs for the supported version."
)
trace_attributes = {"__location__": "server", "raw_path": path}
# Extract W3C trace context headers and store as trace attributes
headers = dict(scope.get("headers", []))
traceparent = headers.get(b"traceparent", b"").decode()
if traceparent:
trace_attributes["traceparent"] = traceparent
tracestate = headers.get(b"tracestate", b"").decode()
if tracestate:
trace_attributes["tracestate"] = tracestate
trace_path = webmethod.descriptive_name or route_path
trace_context = await start_trace(trace_path, trace_attributes)
async def send_with_trace_id(message):
if message["type"] == "http.response.start":
headers = message.get("headers", [])
headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
message["headers"] = headers
await send(message)
try:
return await self.app(scope, receive, send_with_trace_id)
finally:
await end_trace()

View file

@ -0,0 +1,572 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import importlib.resources
import os
import re
import tempfile
from typing import Any
import yaml
from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.prompts import Prompts
from llama_stack.apis.providers import Providers
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
from llama_stack.core.datatypes import Provider, SafetyConfig, StackRunConfig, VectorStoresConfig
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
from llama_stack.core.providers import ProviderImpl, ProviderImplConfig
from llama_stack.core.resolver import ProviderRegistry, resolve_impls
from llama_stack.core.routing_tables.common import CommonRoutingTableImpl
from llama_stack.core.storage.datatypes import (
InferenceStoreReference,
KVStoreReference,
ServerStoresConfig,
SqliteKVStoreConfig,
SqliteSqlStoreConfig,
SqlStoreReference,
StorageBackendConfig,
StorageConfig,
)
from llama_stack.core.store.registry import create_dist_registry
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
logger = get_logger(name=__name__, category="core")
class LlamaStack(
Providers,
Inference,
Agents,
Safety,
SyntheticDataGeneration,
Datasets,
Telemetry,
PostTraining,
VectorIO,
Eval,
Benchmarks,
Scoring,
ScoringFunctions,
DatasetIO,
Models,
Shields,
Inspect,
ToolGroups,
ToolRuntime,
RAGToolRuntime,
Files,
Prompts,
Conversations,
):
pass
RESOURCES = [
("models", Api.models, "register_model", "list_models"),
("shields", Api.shields, "register_shield", "list_shields"),
("datasets", Api.datasets, "register_dataset", "list_datasets"),
(
"scoring_fns",
Api.scoring_functions,
"register_scoring_function",
"list_scoring_functions",
),
("benchmarks", Api.benchmarks, "register_benchmark", "list_benchmarks"),
("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups"),
]
REGISTRY_REFRESH_INTERVAL_SECONDS = 300
REGISTRY_REFRESH_TASK = None
TEST_RECORDING_CONTEXT = None
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
for rsrc, api, register_method, list_method in RESOURCES:
objects = getattr(run_config.registered_resources, rsrc)
if api not in impls:
continue
method = getattr(impls[api], register_method)
for obj in objects:
if hasattr(obj, "provider_id"):
# Do not register models on disabled providers
if not obj.provider_id or obj.provider_id == "__disabled__":
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
continue
logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
# we want to maintain the type information in arguments to method.
# instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
# we use model_dump() to find all the attrs and then getattr to get the still typed value.
await method(**{k: getattr(obj, k) for k in obj.model_dump().keys()})
method = getattr(impls[api], list_method)
response = await method()
objects_to_process = response.data if hasattr(response, "data") else response
for obj in objects_to_process:
logger.debug(
f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
)
async def validate_vector_stores_config(vector_stores_config: VectorStoresConfig | None, impls: dict[Api, Any]):
"""Validate vector stores configuration."""
if vector_stores_config is None:
return
default_embedding_model = vector_stores_config.default_embedding_model
if default_embedding_model is None:
return
provider_id = default_embedding_model.provider_id
model_id = default_embedding_model.model_id
default_model_id = f"{provider_id}/{model_id}"
if Api.models not in impls:
raise ValueError(f"Models API is not available but vector_stores config requires model '{default_model_id}'")
models_impl = impls[Api.models]
response = await models_impl.list_models()
models_list = {m.identifier: m for m in response.data if m.model_type == "embedding"}
default_model = models_list.get(default_model_id)
if default_model is None:
raise ValueError(f"Embedding model '{default_model_id}' not found. Available embedding models: {models_list}")
embedding_dimension = default_model.metadata.get("embedding_dimension")
if embedding_dimension is None:
raise ValueError(f"Embedding model '{default_model_id}' is missing 'embedding_dimension' in metadata")
try:
int(embedding_dimension)
except ValueError as err:
raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err
logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})")
async def validate_safety_config(safety_config: SafetyConfig | None, impls: dict[Api, Any]):
if safety_config is None or safety_config.default_shield_id is None:
return
if Api.shields not in impls:
raise ValueError("Safety configuration requires the shields API to be enabled")
if Api.safety not in impls:
raise ValueError("Safety configuration requires the safety API to be enabled")
shields_impl = impls[Api.shields]
response = await shields_impl.list_shields()
shields_by_id = {shield.identifier: shield for shield in response.data}
default_shield_id = safety_config.default_shield_id
# don't validate if there are no shields registered
if shields_by_id and default_shield_id not in shields_by_id:
available = sorted(shields_by_id)
raise ValueError(
f"Configured default_shield_id '{default_shield_id}' not found among registered shields."
f" Available shields: {available}"
)
class EnvVarError(Exception):
def __init__(self, var_name: str, path: str = ""):
self.var_name = var_name
self.path = path
super().__init__(
f"Environment variable '{var_name}' not set or empty {f'at {path}' if path else ''}. "
f"Use ${{env.{var_name}:=default_value}} to provide a default value, "
f"${{env.{var_name}:+value_if_set}} to make the field conditional, "
f"or ensure the environment variable is set."
)
def replace_env_vars(config: Any, path: str = "") -> Any:
if isinstance(config, dict):
result = {}
for k, v in config.items():
try:
result[k] = replace_env_vars(v, f"{path}.{k}" if path else k)
except EnvVarError as e:
raise EnvVarError(e.var_name, e.path) from None
return result
elif isinstance(config, list):
result = []
for i, v in enumerate(config):
try:
# Special handling for providers: first resolve the provider_id to check if provider
# is disabled so that we can skip config env variable expansion and avoid validation errors
if isinstance(v, dict) and "provider_id" in v:
try:
resolved_provider_id = replace_env_vars(v["provider_id"], f"{path}[{i}].provider_id")
if resolved_provider_id == "__disabled__":
logger.debug(
f"Skipping config env variable expansion for disabled provider: {v.get('provider_id', '')}"
)
# Create a copy with resolved provider_id but original config
disabled_provider = v.copy()
disabled_provider["provider_id"] = resolved_provider_id
continue
except EnvVarError:
# If we can't resolve the provider_id, continue with normal processing
pass
# Normal processing for non-disabled providers
result.append(replace_env_vars(v, f"{path}[{i}]"))
except EnvVarError as e:
raise EnvVarError(e.var_name, e.path) from None
return result
elif isinstance(config, str):
# Pattern supports bash-like syntax: := for default and :+ for conditional and a optional value
pattern = r"\${env\.([A-Z0-9_]+)(?::([=+])([^}]*))?}"
def get_env_var(match: re.Match):
env_var = match.group(1)
operator = match.group(2) # '=' for default, '+' for conditional
value_expr = match.group(3)
env_value = os.environ.get(env_var)
if operator == "=": # Default value syntax: ${env.FOO:=default}
# If the env is set like ${env.FOO:=default} then use the env value when set
if env_value:
value = env_value
else:
# If the env is not set, look for a default value
# value_expr returns empty string (not None) when not matched
# This means ${env.FOO:=} and it's accepted and returns empty string - just like bash
if value_expr == "":
return ""
else:
value = value_expr
elif operator == "+": # Conditional value syntax: ${env.FOO:+value_if_set}
# If the env is set like ${env.FOO:+value_if_set} then use the value_if_set
if env_value:
if value_expr:
value = value_expr
# This means ${env.FOO:+}
else:
# Just like bash, this doesn't care whether the env is set or not and applies
# the value, in this case the empty string
return ""
else:
# Just like bash, this doesn't care whether the env is set or not, since it's not set
# we return an empty string
value = ""
else: # No operator case: ${env.FOO}
if not env_value:
raise EnvVarError(env_var, path)
value = env_value
# expand "~" from the values
return os.path.expanduser(value)
try:
result = re.sub(pattern, get_env_var, config)
# Only apply type conversion if substitution actually happened
if result != config:
return _convert_string_to_proper_type(result)
return result
except EnvVarError as e:
raise EnvVarError(e.var_name, e.path) from None
return config
def _convert_string_to_proper_type(value: str) -> Any:
# This might be tricky depending on what the config type is, if 'str | None' we are
# good, if 'str' we need to keep the empty string... 'str | None' is more common and
# providers config should be typed this way.
# TODO: we could try to load the config class and see if the config has a field with type 'str | None'
# and then convert the empty string to None or not
if value == "":
return None
lowered = value.lower()
if lowered == "true":
return True
elif lowered == "false":
return False
try:
return int(value)
except ValueError:
pass
try:
return float(value)
except ValueError:
pass
return value
def cast_image_name_to_string(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Ensure that any value for a key 'image_name' in a config_dict is a string"""
if "image_name" in config_dict and config_dict["image_name"] is not None:
config_dict["image_name"] = str(config_dict["image_name"])
return config_dict
def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig) -> None:
"""Add internal implementations (inspect and providers) to the implementations dictionary.
Args:
impls: Dictionary of API implementations
run_config: Stack run configuration
"""
inspect_impl = DistributionInspectImpl(
DistributionInspectConfig(run_config=run_config),
deps=impls,
)
impls[Api.inspect] = inspect_impl
providers_impl = ProviderImpl(
ProviderImplConfig(run_config=run_config),
deps=impls,
)
impls[Api.providers] = providers_impl
prompts_impl = PromptServiceImpl(
PromptServiceConfig(run_config=run_config),
deps=impls,
)
impls[Api.prompts] = prompts_impl
conversations_impl = ConversationServiceImpl(
ConversationServiceConfig(run_config=run_config),
deps=impls,
)
impls[Api.conversations] = conversations_impl
def _initialize_storage(run_config: StackRunConfig):
kv_backends: dict[str, StorageBackendConfig] = {}
sql_backends: dict[str, StorageBackendConfig] = {}
for backend_name, backend_config in run_config.storage.backends.items():
type = backend_config.type.value
if type.startswith("kv_"):
kv_backends[backend_name] = backend_config
elif type.startswith("sql_"):
sql_backends[backend_name] = backend_config
else:
raise ValueError(f"Unknown storage backend type: {type}")
from llama_stack.providers.utils.kvstore.kvstore import register_kvstore_backends
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
register_kvstore_backends(kv_backends)
register_sqlstore_backends(sql_backends)
class Stack:
def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
self.run_config = run_config
self.provider_registry = provider_registry
self.impls = None
# Produces a stack of providers for the given run config. Not all APIs may be
# asked for in the run config.
async def initialize(self):
if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
from llama_stack.testing.api_recorder import setup_api_recording
global TEST_RECORDING_CONTEXT
TEST_RECORDING_CONTEXT = setup_api_recording()
if TEST_RECORDING_CONTEXT:
TEST_RECORDING_CONTEXT.__enter__()
logger.info(f"API recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
_initialize_storage(self.run_config)
stores = self.run_config.storage.stores
if not stores.metadata:
raise ValueError("storage.stores.metadata must be configured with a kv_* backend")
dist_registry, _ = await create_dist_registry(stores.metadata, self.run_config.image_name)
policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []
internal_impls = {}
add_internal_implementations(internal_impls, self.run_config)
impls = await resolve_impls(
self.run_config,
self.provider_registry or get_provider_registry(self.run_config),
dist_registry,
policy,
internal_impls,
)
if Api.prompts in impls:
await impls[Api.prompts].initialize()
if Api.conversations in impls:
await impls[Api.conversations].initialize()
await register_resources(self.run_config, impls)
await refresh_registry_once(impls)
await validate_vector_stores_config(self.run_config.vector_stores, impls)
await validate_safety_config(self.run_config.safety, impls)
self.impls = impls
def create_registry_refresh_task(self):
assert self.impls is not None, "Must call initialize() before starting"
global REGISTRY_REFRESH_TASK
REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(self.impls))
def cb(task):
import traceback
if task.cancelled():
logger.error("Model refresh task cancelled")
elif task.exception():
logger.error(f"Model refresh task failed: {task.exception()}")
traceback.print_exception(task.exception())
else:
logger.debug("Model refresh task completed")
REGISTRY_REFRESH_TASK.add_done_callback(cb)
async def shutdown(self):
for impl in self.impls.values():
impl_name = impl.__class__.__name__
logger.info(f"Shutting down {impl_name}")
try:
if hasattr(impl, "shutdown"):
await asyncio.wait_for(impl.shutdown(), timeout=5)
else:
logger.warning(f"No shutdown method for {impl_name}")
except TimeoutError:
logger.exception(f"Shutdown timeout for {impl_name}")
except (Exception, asyncio.CancelledError) as e:
logger.exception(f"Failed to shutdown {impl_name}: {e}")
global TEST_RECORDING_CONTEXT
if TEST_RECORDING_CONTEXT:
try:
TEST_RECORDING_CONTEXT.__exit__(None, None, None)
except Exception as e:
logger.error(f"Error during API recording cleanup: {e}")
global REGISTRY_REFRESH_TASK
if REGISTRY_REFRESH_TASK:
REGISTRY_REFRESH_TASK.cancel()
async def refresh_registry_once(impls: dict[Api, Any]):
logger.debug("refreshing registry")
routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
for routing_table in routing_tables:
await routing_table.refresh()
async def refresh_registry_task(impls: dict[Api, Any]):
logger.info("starting registry refresh task")
while True:
await refresh_registry_once(impls)
await asyncio.sleep(REGISTRY_REFRESH_INTERVAL_SECONDS)
def get_stack_run_config_from_distro(distro: str) -> StackRunConfig:
distro_path = importlib.resources.files("llama_stack") / f"distributions/{distro}/run.yaml"
with importlib.resources.as_file(distro_path) as path:
if not path.exists():
raise ValueError(f"Distribution '{distro}' not found at {distro_path}")
run_config = yaml.safe_load(path.open())
return StackRunConfig(**replace_env_vars(run_config))
def run_config_from_adhoc_config_spec(
adhoc_config_spec: str, provider_registry: ProviderRegistry | None = None
) -> StackRunConfig:
"""
Create an adhoc distribution from a list of API providers.
The list should be of the form "api=provider", e.g. "inference=fireworks". If you have
multiple pairs, separate them with commas or semicolons, e.g. "inference=fireworks,safety=llama-guard,agents=meta-reference"
"""
api_providers = adhoc_config_spec.replace(";", ",").split(",")
provider_registry = provider_registry or get_provider_registry()
distro_dir = tempfile.mkdtemp()
provider_configs_by_api = {}
for api_provider in api_providers:
api_str, provider = api_provider.split("=")
api = Api(api_str)
providers_by_type = provider_registry[api]
provider_spec = providers_by_type.get(provider)
if not provider_spec:
provider_spec = providers_by_type.get(f"inline::{provider}")
if not provider_spec:
provider_spec = providers_by_type.get(f"remote::{provider}")
if not provider_spec:
raise ValueError(
f"Provider {provider} (or remote::{provider} or inline::{provider}) not found for API {api}"
)
# call method "sample_run_config" on the provider spec config class
provider_config_type = instantiate_class_type(provider_spec.config_class)
provider_config = replace_env_vars(provider_config_type.sample_run_config(__distro_dir__=distro_dir))
provider_configs_by_api[api_str] = [
Provider(
provider_id=provider,
provider_type=provider_spec.provider_type,
config=provider_config,
)
]
config = StackRunConfig(
image_name="distro-test",
apis=list(provider_configs_by_api.keys()),
providers=provider_configs_by_api,
storage=StorageConfig(
backends={
"kv_default": SqliteKVStoreConfig(db_path=f"{distro_dir}/kvstore.db"),
"sql_default": SqliteSqlStoreConfig(db_path=f"{distro_dir}/sql_store.db"),
},
stores=ServerStoresConfig(
metadata=KVStoreReference(backend="kv_default", namespace="registry"),
inference=InferenceStoreReference(backend="sql_default", table_name="inference_store"),
conversations=SqlStoreReference(backend="sql_default", table_name="openai_conversations"),
prompts=KVStoreReference(backend="kv_default", namespace="prompts"),
),
),
)
return config

View file

@ -0,0 +1,117 @@
#!/usr/bin/env bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
PYPI_VERSION=${PYPI_VERSION:-}
VIRTUAL_ENV=${VIRTUAL_ENV:-}
set -euo pipefail
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color
error_handler() {
echo "Error occurred in script at line: ${1}" >&2
exit 1
}
trap 'error_handler ${LINENO}' ERR
if [ $# -lt 3 ]; then
echo "Usage: $0 <env_type> <env_path_or_name> <port> [--config <yaml_config>]"
exit 1
fi
env_type="$1"
shift
env_path_or_name="$1"
container_image="localhost/$env_path_or_name"
shift
port="$1"
shift
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
# Initialize variables
yaml_config=""
other_args=""
# Process remaining arguments
while [[ $# -gt 0 ]]; do
case "$1" in
--config)
if [[ -n "$2" ]]; then
yaml_config="$2"
shift 2
else
echo -e "${RED}Error: $1 requires a CONFIG argument${NC}" >&2
exit 1
fi
;;
*)
other_args="$other_args $1"
shift
;;
esac
done
# Check if yaml_config is required
if [[ "$env_type" == "venv" ]] && [ -z "$yaml_config" ]; then
echo -e "${RED}Error: --config is required for venv environment${NC}" >&2
exit 1
fi
PYTHON_BINARY="python"
case "$env_type" in
"venv")
if [ -n "$VIRTUAL_ENV" ] && [ "$VIRTUAL_ENV" == "$env_path_or_name" ]; then
echo -e "${GREEN}Virtual environment already activated${NC}" >&2
else
# Activate virtual environment
if [ ! -d "$env_path_or_name" ]; then
echo -e "${RED}Error: Virtual environment not found at $env_path_or_name${NC}" >&2
exit 1
fi
if [ ! -f "$env_path_or_name/bin/activate" ]; then
echo -e "${RED}Error: Virtual environment activate binary not found at $env_path_or_name/bin/activate" >&2
exit 1
fi
source "$env_path_or_name/bin/activate"
fi
;;
*)
# Handle unsupported env_types here
echo -e "${RED}Error: Unsupported environment type '$env_type'. Only 'venv' is supported.${NC}" >&2
exit 1
;;
esac
if [[ "$env_type" == "venv" ]]; then
set -x
if [ -n "$yaml_config" ]; then
yaml_config_arg="$yaml_config"
else
yaml_config_arg=""
fi
llama stack run \
$yaml_config_arg \
--port "$port" \
$other_args
elif [[ "$env_type" == "container" ]]; then
echo -e "${RED}Warning: Llama Stack no longer supports running Containers via the 'llama stack run' command.${NC}"
echo -e "Please refer to the documentation for more information: https://llamastack.github.io/latest/distributions/building_distro.html#llama-stack-build"
exit 1
fi

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,287 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
from abc import abstractmethod
from enum import StrEnum
from pathlib import Path
from typing import Annotated, Literal
from pydantic import BaseModel, Field, field_validator
class StorageBackendType(StrEnum):
KV_REDIS = "kv_redis"
KV_SQLITE = "kv_sqlite"
KV_POSTGRES = "kv_postgres"
KV_MONGODB = "kv_mongodb"
SQL_SQLITE = "sql_sqlite"
SQL_POSTGRES = "sql_postgres"
class CommonConfig(BaseModel):
namespace: str | None = Field(
default=None,
description="All keys will be prefixed with this namespace",
)
class RedisKVStoreConfig(CommonConfig):
type: Literal[StorageBackendType.KV_REDIS] = StorageBackendType.KV_REDIS
host: str = "localhost"
port: int = 6379
@property
def url(self) -> str:
return f"redis://{self.host}:{self.port}"
@classmethod
def pip_packages(cls) -> list[str]:
return ["redis"]
@classmethod
def sample_run_config(cls):
return {
"type": StorageBackendType.KV_REDIS.value,
"host": "${env.REDIS_HOST:=localhost}",
"port": "${env.REDIS_PORT:=6379}",
}
class SqliteKVStoreConfig(CommonConfig):
type: Literal[StorageBackendType.KV_SQLITE] = StorageBackendType.KV_SQLITE
db_path: str = Field(
description="File path for the sqlite database",
)
@classmethod
def pip_packages(cls) -> list[str]:
return ["aiosqlite"]
@classmethod
def sample_run_config(cls, __distro_dir__: str, db_name: str = "kvstore.db"):
return {
"type": StorageBackendType.KV_SQLITE.value,
"db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + db_name,
}
class PostgresKVStoreConfig(CommonConfig):
type: Literal[StorageBackendType.KV_POSTGRES] = StorageBackendType.KV_POSTGRES
host: str = "localhost"
port: int | str = 5432
db: str = "llamastack"
user: str
password: str | None = None
ssl_mode: str | None = None
ca_cert_path: str | None = None
table_name: str = "llamastack_kvstore"
@classmethod
def sample_run_config(cls, table_name: str = "llamastack_kvstore", **kwargs):
return {
"type": StorageBackendType.KV_POSTGRES.value,
"host": "${env.POSTGRES_HOST:=localhost}",
"port": "${env.POSTGRES_PORT:=5432}",
"db": "${env.POSTGRES_DB:=llamastack}",
"user": "${env.POSTGRES_USER:=llamastack}",
"password": "${env.POSTGRES_PASSWORD:=llamastack}",
"table_name": "${env.POSTGRES_TABLE_NAME:=" + table_name + "}",
}
@classmethod
@field_validator("table_name")
def validate_table_name(cls, v: str) -> str:
# PostgreSQL identifiers rules:
# - Must start with a letter or underscore
# - Can contain letters, numbers, and underscores
# - Maximum length is 63 bytes
pattern = r"^[a-zA-Z_][a-zA-Z0-9_]*$"
if not re.match(pattern, v):
raise ValueError(
"Invalid table name. Must start with letter or underscore and contain only letters, numbers, and underscores"
)
if len(v) > 63:
raise ValueError("Table name must be less than 63 characters")
return v
@classmethod
def pip_packages(cls) -> list[str]:
return ["psycopg2-binary"]
class MongoDBKVStoreConfig(CommonConfig):
type: Literal[StorageBackendType.KV_MONGODB] = StorageBackendType.KV_MONGODB
host: str = "localhost"
port: int = 27017
db: str = "llamastack"
user: str | None = None
password: str | None = None
collection_name: str = "llamastack_kvstore"
@classmethod
def pip_packages(cls) -> list[str]:
return ["pymongo"]
@classmethod
def sample_run_config(cls, collection_name: str = "llamastack_kvstore"):
return {
"type": StorageBackendType.KV_MONGODB.value,
"host": "${env.MONGODB_HOST:=localhost}",
"port": "${env.MONGODB_PORT:=5432}",
"db": "${env.MONGODB_DB}",
"user": "${env.MONGODB_USER}",
"password": "${env.MONGODB_PASSWORD}",
"collection_name": "${env.MONGODB_COLLECTION_NAME:=" + collection_name + "}",
}
class SqlAlchemySqlStoreConfig(BaseModel):
@property
@abstractmethod
def engine_str(self) -> str: ...
# TODO: move this when we have a better way to specify dependencies with internal APIs
@classmethod
def pip_packages(cls) -> list[str]:
return ["sqlalchemy[asyncio]"]
class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
type: Literal[StorageBackendType.SQL_SQLITE] = StorageBackendType.SQL_SQLITE
db_path: str = Field(
description="Database path, e.g. ~/.llama/distributions/ollama/sqlstore.db",
)
@property
def engine_str(self) -> str:
return "sqlite+aiosqlite:///" + Path(self.db_path).expanduser().as_posix()
@classmethod
def sample_run_config(cls, __distro_dir__: str, db_name: str = "sqlstore.db"):
return {
"type": StorageBackendType.SQL_SQLITE.value,
"db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + db_name,
}
@classmethod
def pip_packages(cls) -> list[str]:
return super().pip_packages() + ["aiosqlite"]
class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
type: Literal[StorageBackendType.SQL_POSTGRES] = StorageBackendType.SQL_POSTGRES
host: str = "localhost"
port: int | str = 5432
db: str = "llamastack"
user: str
password: str | None = None
@property
def engine_str(self) -> str:
return f"postgresql+asyncpg://{self.user}:{self.password}@{self.host}:{self.port}/{self.db}"
@classmethod
def pip_packages(cls) -> list[str]:
return super().pip_packages() + ["asyncpg"]
@classmethod
def sample_run_config(cls, **kwargs):
return {
"type": StorageBackendType.SQL_POSTGRES.value,
"host": "${env.POSTGRES_HOST:=localhost}",
"port": "${env.POSTGRES_PORT:=5432}",
"db": "${env.POSTGRES_DB:=llamastack}",
"user": "${env.POSTGRES_USER:=llamastack}",
"password": "${env.POSTGRES_PASSWORD:=llamastack}",
}
# reference = (backend_name, table_name)
class SqlStoreReference(BaseModel):
"""A reference to a 'SQL-like' persistent store. A table name must be provided."""
table_name: str = Field(
description="Name of the table to use for the SqlStore",
)
backend: str = Field(
description="Name of backend from storage.backends",
)
# reference = (backend_name, namespace)
class KVStoreReference(BaseModel):
"""A reference to a 'key-value' persistent store. A namespace must be provided."""
namespace: str = Field(
description="Key prefix for KVStore backends",
)
backend: str = Field(
description="Name of backend from storage.backends",
)
StorageBackendConfig = Annotated[
RedisKVStoreConfig
| SqliteKVStoreConfig
| PostgresKVStoreConfig
| MongoDBKVStoreConfig
| SqliteSqlStoreConfig
| PostgresSqlStoreConfig,
Field(discriminator="type"),
]
class InferenceStoreReference(SqlStoreReference):
"""Inference store configuration with queue tuning."""
max_write_queue_size: int = Field(
default=10000,
description="Max queued writes for inference store",
)
num_writers: int = Field(
default=4,
description="Number of concurrent background writers",
)
class ResponsesStoreReference(InferenceStoreReference):
"""Responses store configuration with queue tuning."""
class ServerStoresConfig(BaseModel):
metadata: KVStoreReference | None = Field(
default=None,
description="Metadata store configuration (uses KV backend)",
)
inference: InferenceStoreReference | None = Field(
default=None,
description="Inference store configuration (uses SQL backend)",
)
conversations: SqlStoreReference | None = Field(
default=None,
description="Conversations store configuration (uses SQL backend)",
)
responses: ResponsesStoreReference | None = Field(
default=None,
description="Responses store configuration (uses SQL backend)",
)
prompts: KVStoreReference | None = Field(
default=None,
description="Prompts store configuration (uses KV backend)",
)
class StorageConfig(BaseModel):
backends: dict[str, StorageBackendConfig] = Field(
description="Named backend configurations (e.g., 'default', 'cache')",
)
stores: ServerStoresConfig = Field(
default_factory=lambda: ServerStoresConfig(),
description="Named references to storage backends used by the stack core",
)

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .registry import * # noqa: F401 F403

View file

@ -0,0 +1,199 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from contextlib import asynccontextmanager
from typing import Protocol
import pydantic
from llama_stack.core.datatypes import RoutableObjectWithProvider
from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
logger = get_logger(__name__, category="core::registry")
class DistributionRegistry(Protocol):
async def get_all(self) -> list[RoutableObjectWithProvider]: ...
async def initialize(self) -> None: ...
async def get(self, identifier: str) -> RoutableObjectWithProvider | None: ...
def get_cached(self, identifier: str) -> RoutableObjectWithProvider | None: ...
async def update(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: ...
async def register(self, obj: RoutableObjectWithProvider) -> bool: ...
async def delete(self, type: str, identifier: str) -> None: ...
REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v10"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
def _get_registry_key_range() -> tuple[str, str]:
"""Returns the start and end keys for the registry range query."""
start_key = f"{REGISTER_PREFIX}:{KEY_VERSION}"
return start_key, f"{start_key}\xff"
def _parse_registry_values(values: list[str]) -> list[RoutableObjectWithProvider]:
"""Utility function to parse registry values into RoutableObjectWithProvider objects."""
all_objects = []
for value in values:
try:
obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value)
all_objects.append(obj)
except pydantic.ValidationError as e:
logger.error(f"Error parsing registry value, raw value: {value}. Error: {e}")
continue
return all_objects
class DiskDistributionRegistry(DistributionRegistry):
def __init__(self, kvstore: KVStore):
self.kvstore = kvstore
async def initialize(self) -> None:
pass
def get_cached(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
# Disk registry does not have a cache
raise NotImplementedError("Disk registry does not have a cache")
async def get_all(self) -> list[RoutableObjectWithProvider]:
start_key, end_key = _get_registry_key_range()
values = await self.kvstore.values_in_range(start_key, end_key)
return _parse_registry_values(values)
async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
json_str = await self.kvstore.get(KEY_FORMAT.format(type=type, identifier=identifier))
if not json_str:
return None
try:
return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str)
except pydantic.ValidationError as e:
logger.error(f"Error parsing registry value for {type}:{identifier}, raw value: {json_str}. Error: {e}")
return None
async def update(self, obj: RoutableObjectWithProvider) -> None:
await self.kvstore.set(
KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),
obj.model_dump_json(),
)
return obj
async def register(self, obj: RoutableObjectWithProvider) -> bool:
existing_obj = await self.get(obj.type, obj.identifier)
if existing_obj and existing_obj != obj:
raise ValueError(
f"Object of type '{obj.type}' and identifier '{obj.identifier}' already exists. "
"Unregister it first if you want to replace it."
)
await self.kvstore.set(
KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),
obj.model_dump_json(),
)
return True
async def delete(self, type: str, identifier: str) -> None:
await self.kvstore.delete(KEY_FORMAT.format(type=type, identifier=identifier))
class CachedDiskDistributionRegistry(DiskDistributionRegistry):
def __init__(self, kvstore: KVStore):
super().__init__(kvstore)
self.cache: dict[tuple[str, str], RoutableObjectWithProvider] = {}
self._initialized = False
self._initialize_lock = asyncio.Lock()
self._cache_lock = asyncio.Lock()
@asynccontextmanager
async def _locked_cache(self):
"""Context manager for safely accessing the cache with a lock."""
async with self._cache_lock:
yield self.cache
async def _ensure_initialized(self):
"""Ensures the registry is initialized before operations."""
if self._initialized:
return
async with self._initialize_lock:
if self._initialized:
return
start_key, end_key = _get_registry_key_range()
values = await self.kvstore.values_in_range(start_key, end_key)
objects = _parse_registry_values(values)
async with self._locked_cache() as cache:
for obj in objects:
cache_key = (obj.type, obj.identifier)
cache[cache_key] = obj
self._initialized = True
async def initialize(self) -> None:
await self._ensure_initialized()
def get_cached(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
return self.cache.get((type, identifier), None)
async def get_all(self) -> list[RoutableObjectWithProvider]:
await self._ensure_initialized()
async with self._locked_cache() as cache:
return list(cache.values())
async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
await self._ensure_initialized()
cache_key = (type, identifier)
async with self._locked_cache() as cache:
return cache.get(cache_key, None)
async def register(self, obj: RoutableObjectWithProvider) -> bool:
await self._ensure_initialized()
success = await super().register(obj)
if success:
cache_key = (obj.type, obj.identifier)
async with self._locked_cache() as cache:
cache[cache_key] = obj
return success
async def update(self, obj: RoutableObjectWithProvider) -> None:
await super().update(obj)
cache_key = (obj.type, obj.identifier)
async with self._locked_cache() as cache:
cache[cache_key] = obj
return obj
async def delete(self, type: str, identifier: str) -> None:
await super().delete(type, identifier)
cache_key = (type, identifier)
async with self._locked_cache() as cache:
if cache_key in cache:
del cache[cache_key]
async def create_dist_registry(
metadata_store: KVStoreReference, image_name: str
) -> tuple[CachedDiskDistributionRegistry, KVStore]:
# instantiate kvstore for storing and retrieving distribution metadata
dist_kvstore = await kvstore_impl(metadata_store)
dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
await dist_registry.initialize()
return dist_registry, dist_kvstore

View file

@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .telemetry import Telemetry
from .trace_protocol import serialize_value, trace_protocol
from .tracing import (
CURRENT_TRACE_CONTEXT,
ROOT_SPAN_MARKERS,
end_trace,
enqueue_event,
get_current_span,
setup_logger,
span,
start_trace,
)
__all__ = [
"Telemetry",
"trace_protocol",
"serialize_value",
"CURRENT_TRACE_CONTEXT",
"ROOT_SPAN_MARKERS",
"end_trace",
"enqueue_event",
"get_current_span",
"setup_logger",
"span",
"start_trace",
]

View file

@ -0,0 +1,250 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import threading
from typing import Any
from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from llama_stack.apis.telemetry import (
Event,
MetricEvent,
SpanEndPayload,
SpanStartPayload,
SpanStatus,
StructuredLogEvent,
UnstructuredLogEvent,
)
from llama_stack.apis.telemetry import (
Telemetry as TelemetryBase,
)
from llama_stack.core.telemetry.tracing import ROOT_SPAN_MARKERS
from llama_stack.log import get_logger
_GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
"active_spans": {},
"counters": {},
"gauges": {},
"up_down_counters": {},
}
_global_lock = threading.Lock()
_TRACER_PROVIDER = None
logger = get_logger(name=__name__, category="telemetry")
def is_tracing_enabled(tracer):
with tracer.start_as_current_span("check_tracing") as span:
return span.is_recording()
class Telemetry(TelemetryBase):
def __init__(self) -> None:
self.meter = None
global _TRACER_PROVIDER
# Initialize the correct span processor based on the provider state.
# This is needed since once the span processor is set, it cannot be unset.
# Recreating the telemetry adapter multiple times will result in duplicate span processors.
# Since the library client can be recreated multiple times in a notebook,
# the kernel will hold on to the span processor and cause duplicate spans to be written.
if os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT"):
if _TRACER_PROVIDER is None:
provider = TracerProvider()
trace.set_tracer_provider(provider)
_TRACER_PROVIDER = provider
# Use single OTLP endpoint for all telemetry signals
# Let OpenTelemetry SDK handle endpoint construction automatically
# The SDK will read OTEL_EXPORTER_OTLP_ENDPOINT and construct appropriate URLs
# https://opentelemetry.io/docs/languages/sdk-configuration/otlp-exporter
span_exporter = OTLPSpanExporter()
span_processor = BatchSpanProcessor(span_exporter)
trace.get_tracer_provider().add_span_processor(span_processor)
metric_reader = PeriodicExportingMetricReader(OTLPMetricExporter())
metric_provider = MeterProvider(metric_readers=[metric_reader])
metrics.set_meter_provider(metric_provider)
self.is_otel_endpoint_set = True
else:
logger.warning("OTEL_EXPORTER_OTLP_ENDPOINT is not set, skipping telemetry")
self.is_otel_endpoint_set = False
self.meter = metrics.get_meter(__name__)
self._lock = _global_lock
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
if self.is_otel_endpoint_set:
trace.get_tracer_provider().force_flush()
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
if isinstance(event, UnstructuredLogEvent):
self._log_unstructured(event, ttl_seconds)
elif isinstance(event, MetricEvent):
self._log_metric(event)
elif isinstance(event, StructuredLogEvent):
self._log_structured(event, ttl_seconds)
else:
raise ValueError(f"Unknown event type: {event}")
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
with self._lock:
# Use global storage instead of instance storage
span_id = int(event.span_id, 16)
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
if span:
timestamp_ns = int(event.timestamp.timestamp() * 1e9)
span.add_event(
name=event.type.value,
attributes={
"message": event.message,
"severity": event.severity.value,
"__ttl__": ttl_seconds,
**(event.attributes or {}),
},
timestamp=timestamp_ns,
)
else:
print(f"Warning: No active span found for span_id {span_id}. Dropping event: {event}")
def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter:
assert self.meter is not None
if name not in _GLOBAL_STORAGE["counters"]:
_GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
name=name,
unit=unit,
description=f"Counter for {name}",
)
return _GLOBAL_STORAGE["counters"][name]
def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
assert self.meter is not None
if name not in _GLOBAL_STORAGE["gauges"]:
_GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
name=name,
unit=unit,
description=f"Gauge for {name}",
)
return _GLOBAL_STORAGE["gauges"][name]
def _log_metric(self, event: MetricEvent) -> None:
# Add metric as an event to the current span
try:
with self._lock:
# Only try to add to span if we have a valid span_id
if event.span_id:
try:
span_id = int(event.span_id, 16)
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
if span:
timestamp_ns = int(event.timestamp.timestamp() * 1e9)
span.add_event(
name=f"metric.{event.metric}",
attributes={
"value": event.value,
"unit": event.unit,
**(event.attributes or {}),
},
timestamp=timestamp_ns,
)
except (ValueError, KeyError):
# Invalid span_id or span not found, but we already logged to console above
pass
except Exception:
# Lock acquisition failed
logger.debug("Failed to acquire lock to add metric to span")
# Log to OpenTelemetry meter if available
if self.meter is None:
return
if isinstance(event.value, int):
counter = self._get_or_create_counter(event.metric, event.unit)
counter.add(event.value, attributes=event.attributes)
elif isinstance(event.value, float):
up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit)
up_down_counter.add(event.value, attributes=event.attributes)
def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
assert self.meter is not None
if name not in _GLOBAL_STORAGE["up_down_counters"]:
_GLOBAL_STORAGE["up_down_counters"][name] = self.meter.create_up_down_counter(
name=name,
unit=unit,
description=f"UpDownCounter for {name}",
)
return _GLOBAL_STORAGE["up_down_counters"][name]
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
with self._lock:
span_id = int(event.span_id, 16)
tracer = trace.get_tracer(__name__)
if event.attributes is None:
event.attributes = {}
event.attributes["__ttl__"] = ttl_seconds
# Extract these W3C trace context attributes so they are not written to
# underlying storage, as we just need them to propagate the trace context.
traceparent = event.attributes.pop("traceparent", None)
tracestate = event.attributes.pop("tracestate", None)
if traceparent:
# If we have a traceparent header value, we're not the root span.
for root_attribute in ROOT_SPAN_MARKERS:
event.attributes.pop(root_attribute, None)
if isinstance(event.payload, SpanStartPayload):
# Check if span already exists to prevent duplicates
if span_id in _GLOBAL_STORAGE["active_spans"]:
return
context = None
if event.payload.parent_span_id:
parent_span_id = int(event.payload.parent_span_id, 16)
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
context = trace.set_span_in_context(parent_span)
elif traceparent:
carrier = {
"traceparent": traceparent,
"tracestate": tracestate,
}
context = TraceContextTextMapPropagator().extract(carrier=carrier)
span = tracer.start_span(
name=event.payload.name,
context=context,
attributes=event.attributes or {},
)
_GLOBAL_STORAGE["active_spans"][span_id] = span
elif isinstance(event.payload, SpanEndPayload):
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
if span:
if event.attributes:
span.set_attributes(event.attributes)
status = (
trace.Status(status_code=trace.StatusCode.OK)
if event.payload.status == SpanStatus.OK
else trace.Status(status_code=trace.StatusCode.ERROR)
)
span.set_status(status)
span.end()
_GLOBAL_STORAGE["active_spans"].pop(span_id, None)
else:
raise ValueError(f"Unknown structured log event: {event}")

View file

@ -0,0 +1,145 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import inspect
import json
from collections.abc import AsyncGenerator, Callable
from functools import wraps
from typing import Any, cast
from pydantic import BaseModel
from llama_stack.models.llama.datatypes import Primitive
type JSONValue = Primitive | list["JSONValue"] | dict[str, "JSONValue"]
def serialize_value(value: Any) -> str:
return str(_prepare_for_json(value))
def _prepare_for_json(value: Any) -> JSONValue:
"""Serialize a single value into JSON-compatible format."""
if value is None:
return ""
elif isinstance(value, str | int | float | bool):
return value
elif hasattr(value, "_name_"):
return cast(str, value._name_)
elif isinstance(value, BaseModel):
return cast(JSONValue, json.loads(value.model_dump_json()))
elif isinstance(value, list | tuple | set):
return [_prepare_for_json(item) for item in value]
elif isinstance(value, dict):
return {str(k): _prepare_for_json(v) for k, v in value.items()}
else:
try:
json.dumps(value)
return cast(JSONValue, value)
except Exception:
return str(value)
def trace_protocol[T: type[Any]](cls: T) -> T:
"""
A class decorator that automatically traces all methods in a protocol/base class
and its inheriting classes.
"""
def trace_method(method: Callable[..., Any]) -> Callable[..., Any]:
is_async = asyncio.iscoroutinefunction(method)
is_async_gen = inspect.isasyncgenfunction(method)
def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple[str, str, dict[str, Primitive]]:
class_name = self.__class__.__name__
method_name = method.__name__
span_type = "async_generator" if is_async_gen else "async" if is_async else "sync"
sig = inspect.signature(method)
param_names = list(sig.parameters.keys())[1:] # Skip 'self'
combined_args: dict[str, str] = {}
for i, arg in enumerate(args):
param_name = param_names[i] if i < len(param_names) else f"position_{i + 1}"
combined_args[param_name] = serialize_value(arg)
for k, v in kwargs.items():
combined_args[str(k)] = serialize_value(v)
span_attributes: dict[str, Primitive] = {
"__autotraced__": True,
"__class__": class_name,
"__method__": method_name,
"__type__": span_type,
"__args__": json.dumps(combined_args),
}
return class_name, method_name, span_attributes
@wraps(method)
async def async_gen_wrapper(self: Any, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]:
from llama_stack.core.telemetry import tracing
class_name, method_name, span_attributes = create_span_context(self, *args, **kwargs)
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
count = 0
try:
async for item in method(self, *args, **kwargs):
yield item
count += 1
finally:
span.set_attribute("chunk_count", count)
@wraps(method)
async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
from llama_stack.core.telemetry import tracing
class_name, method_name, span_attributes = create_span_context(self, *args, **kwargs)
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
try:
result = await method(self, *args, **kwargs)
span.set_attribute("output", serialize_value(result))
return result
except Exception as e:
span.set_attribute("error", str(e))
raise
@wraps(method)
def sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
from llama_stack.core.telemetry import tracing
class_name, method_name, span_attributes = create_span_context(self, *args, **kwargs)
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
try:
result = method(self, *args, **kwargs)
span.set_attribute("output", serialize_value(result))
return result
except Exception as e:
span.set_attribute("error", str(e))
raise
if is_async_gen:
return async_gen_wrapper
elif is_async:
return async_wrapper
else:
return sync_wrapper
original_init_subclass = cast(Callable[..., Any] | None, getattr(cls, "__init_subclass__", None))
def __init_subclass__(cls_child: type[Any], **kwargs: Any) -> None: # noqa: N807
if original_init_subclass:
cast(Callable[..., None], original_init_subclass)(**kwargs)
for name, method in vars(cls_child).items():
if inspect.isfunction(method) and not name.startswith("_"):
setattr(cls_child, name, trace_method(method)) # noqa: B010
cls_any = cast(Any, cls)
cls_any.__init_subclass__ = classmethod(__init_subclass__)
return cls

View file

@ -0,0 +1,388 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import contextvars
import logging # allow-direct-logging
import queue
import secrets
import sys
import threading
import time
from collections.abc import Callable
from datetime import UTC, datetime
from functools import wraps
from typing import Any, Self
from llama_stack.apis.telemetry import (
Event,
LogSeverity,
Span,
SpanEndPayload,
SpanStartPayload,
SpanStatus,
StructuredLogEvent,
Telemetry,
UnstructuredLogEvent,
)
from llama_stack.core.telemetry.trace_protocol import serialize_value
from llama_stack.log import get_logger
logger = get_logger(__name__, category="core")
# Fallback logger that does NOT propagate to TelemetryHandler to avoid recursion
_fallback_logger = logging.getLogger("llama_stack.telemetry.background")
if not _fallback_logger.handlers:
_fallback_logger.propagate = False
_fallback_logger.setLevel(logging.ERROR)
_fallback_handler = logging.StreamHandler(sys.stderr)
_fallback_handler.setLevel(logging.ERROR)
_fallback_handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s"))
_fallback_logger.addHandler(_fallback_handler)
INVALID_SPAN_ID = 0x0000000000000000
INVALID_TRACE_ID = 0x00000000000000000000000000000000
ROOT_SPAN_MARKERS = ["__root__", "__root_span__"]
# The logical root span may not be visible to this process if a parent context
# is passed in. The local root span is the first local span in a trace.
LOCAL_ROOT_SPAN_MARKER = "__local_root_span__"
def trace_id_to_str(trace_id: int) -> str:
"""Convenience trace ID formatting method
Args:
trace_id: Trace ID int
Returns:
The trace ID as 32-byte hexadecimal string
"""
return format(trace_id, "032x")
def span_id_to_str(span_id: int) -> str:
"""Convenience span ID formatting method
Args:
span_id: Span ID int
Returns:
The span ID as 16-byte hexadecimal string
"""
return format(span_id, "016x")
def generate_span_id() -> str:
span_id = secrets.randbits(64)
while span_id == INVALID_SPAN_ID:
span_id = secrets.randbits(64)
return span_id_to_str(span_id)
def generate_trace_id() -> str:
trace_id = secrets.randbits(128)
while trace_id == INVALID_TRACE_ID:
trace_id = secrets.randbits(128)
return trace_id_to_str(trace_id)
LOG_QUEUE_FULL_LOG_INTERVAL_SECONDS = 60.0
class BackgroundLogger:
def __init__(self, api: Telemetry, capacity: int = 100000):
self.api = api
self.log_queue: queue.Queue[Any] = queue.Queue(maxsize=capacity)
self.worker_thread = threading.Thread(target=self._worker, daemon=True)
self.worker_thread.start()
self._last_queue_full_log_time: float = 0.0
self._dropped_since_last_notice: int = 0
def log_event(self, event: Event) -> None:
try:
self.log_queue.put_nowait(event)
except queue.Full:
# Aggregate drops and emit at most once per interval via fallback logger
self._dropped_since_last_notice += 1
current_time = time.time()
if current_time - self._last_queue_full_log_time >= LOG_QUEUE_FULL_LOG_INTERVAL_SECONDS:
_fallback_logger.error(
"Log queue is full; dropped %d events since last notice",
self._dropped_since_last_notice,
)
self._last_queue_full_log_time = current_time
self._dropped_since_last_notice = 0
def _worker(self):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(self._process_logs())
async def _process_logs(self):
while True:
try:
event = self.log_queue.get()
await self.api.log_event(event)
except Exception:
import traceback
traceback.print_exc()
print("Error processing log event")
finally:
self.log_queue.task_done()
def __del__(self) -> None:
self.log_queue.join()
BACKGROUND_LOGGER: BackgroundLogger | None = None
def enqueue_event(event: Event) -> None:
"""Enqueue a telemetry event to the background logger if available.
This provides a non-blocking path for routers and other hot paths to
submit telemetry without awaiting the Telemetry API, reducing contention
with the main event loop.
"""
global BACKGROUND_LOGGER
if BACKGROUND_LOGGER is None:
raise RuntimeError("Telemetry API not initialized")
BACKGROUND_LOGGER.log_event(event)
class TraceContext:
def __init__(self, logger: BackgroundLogger, trace_id: str):
self.logger = logger
self.trace_id = trace_id
self.spans: list[Span] = []
def push_span(self, name: str, attributes: dict[str, Any] | None = None) -> Span:
current_span = self.get_current_span()
span = Span(
span_id=generate_span_id(),
trace_id=self.trace_id,
name=name,
start_time=datetime.now(UTC),
parent_span_id=current_span.span_id if current_span else None,
attributes=attributes,
)
self.logger.log_event(
StructuredLogEvent(
trace_id=span.trace_id,
span_id=span.span_id,
timestamp=span.start_time,
attributes=span.attributes,
payload=SpanStartPayload(
name=span.name,
parent_span_id=span.parent_span_id,
),
)
)
self.spans.append(span)
return span
def pop_span(self, status: SpanStatus = SpanStatus.OK) -> None:
span = self.spans.pop()
if span is not None:
self.logger.log_event(
StructuredLogEvent(
trace_id=span.trace_id,
span_id=span.span_id,
timestamp=span.start_time,
attributes=span.attributes,
payload=SpanEndPayload(
status=status,
),
)
)
def get_current_span(self) -> Span | None:
return self.spans[-1] if self.spans else None
CURRENT_TRACE_CONTEXT: contextvars.ContextVar[TraceContext | None] = contextvars.ContextVar(
"trace_context", default=None
)
def setup_logger(api: Telemetry, level: int = logging.INFO):
global BACKGROUND_LOGGER
if BACKGROUND_LOGGER is None:
BACKGROUND_LOGGER = BackgroundLogger(api)
root_logger = logging.getLogger()
root_logger.setLevel(level)
root_logger.addHandler(TelemetryHandler())
async def start_trace(name: str, attributes: dict[str, Any] | None = None) -> TraceContext | None:
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
if BACKGROUND_LOGGER is None:
logger.debug("No Telemetry implementation set. Skipping trace initialization...")
return None
trace_id = generate_trace_id()
context = TraceContext(BACKGROUND_LOGGER, trace_id)
# Mark this span as the root for the trace for now. The processing of
# traceparent context if supplied comes later and will result in the
# ROOT_SPAN_MARKERS being removed. Also mark this is the 'local' root,
# i.e. the root of the spans originating in this process as this is
# needed to ensure that we insert this 'local' root span's id into
# the trace record in sqlite store.
attributes = dict.fromkeys(ROOT_SPAN_MARKERS, True) | {LOCAL_ROOT_SPAN_MARKER: True} | (attributes or {})
context.push_span(name, attributes)
CURRENT_TRACE_CONTEXT.set(context)
return context
async def end_trace(status: SpanStatus = SpanStatus.OK):
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT.get()
if context is None:
logger.debug("No trace context to end")
return
context.pop_span(status)
CURRENT_TRACE_CONTEXT.set(None)
def severity(levelname: str) -> LogSeverity:
if levelname == "DEBUG":
return LogSeverity.DEBUG
elif levelname == "INFO":
return LogSeverity.INFO
elif levelname == "WARNING":
return LogSeverity.WARN
elif levelname == "ERROR":
return LogSeverity.ERROR
elif levelname == "CRITICAL":
return LogSeverity.CRITICAL
else:
raise ValueError(f"Unknown log level: {levelname}")
# TODO: ideally, the actual emitting should be done inside a separate daemon
# process completely isolated from the server
class TelemetryHandler(logging.Handler):
def emit(self, record: logging.LogRecord) -> None:
# horrendous hack to avoid logging from asyncio and getting into an infinite loop
if record.module in ("asyncio", "selector_events"):
return
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT.get()
if context is None:
return
span = context.get_current_span()
if span is None:
return
enqueue_event(
UnstructuredLogEvent(
trace_id=span.trace_id,
span_id=span.span_id,
timestamp=datetime.now(UTC),
message=self.format(record),
severity=severity(record.levelname),
)
)
def close(self) -> None:
pass
class SpanContextManager:
def __init__(self, name: str, attributes: dict[str, Any] | None = None):
self.name = name
self.attributes = attributes
self.span: Span | None = None
def __enter__(self) -> Self:
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT.get()
if not context:
logger.debug("No trace context to push span")
return self
self.span = context.push_span(self.name, self.attributes)
return self
def __exit__(self, exc_type, exc_value, traceback) -> None:
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT.get()
if not context:
logger.debug("No trace context to pop span")
return
context.pop_span()
def set_attribute(self, key: str, value: Any) -> None:
if self.span:
if self.span.attributes is None:
self.span.attributes = {}
self.span.attributes[key] = serialize_value(value)
async def __aenter__(self) -> Self:
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT.get()
if not context:
logger.debug("No trace context to push span")
return self
self.span = context.push_span(self.name, self.attributes)
return self
async def __aexit__(self, exc_type, exc_value, traceback) -> None:
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT.get()
if not context:
logger.debug("No trace context to pop span")
return
context.pop_span()
def __call__(self, func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(func)
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
with self:
return func(*args, **kwargs)
@wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
async with self:
return await func(*args, **kwargs)
@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
if asyncio.iscoroutinefunction(func):
return async_wrapper(*args, **kwargs)
else:
return sync_wrapper(*args, **kwargs)
return wrapper
def span(name: str, attributes: dict[str, Any] | None = None) -> SpanContextManager:
return SpanContextManager(name, attributes)
def get_current_span() -> Span | None:
global CURRENT_TRACE_CONTEXT
if CURRENT_TRACE_CONTEXT is None:
logger.debug("No trace context to get current span")
return None
context = CURRENT_TRACE_CONTEXT.get()
if context:
return context.get_current_span()
return None

View file

@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from contextvars import ContextVar
from llama_stack.core.request_headers import PROVIDER_DATA_VAR
TEST_CONTEXT: ContextVar[str | None] = ContextVar("llama_stack_test_context", default=None)
def get_test_context() -> str | None:
return TEST_CONTEXT.get()
def set_test_context(value: str | None):
return TEST_CONTEXT.set(value)
def reset_test_context(token) -> None:
TEST_CONTEXT.reset(token)
def sync_test_context_from_provider_data():
"""Sync test context from provider data when running in server test mode."""
if "LLAMA_STACK_TEST_INFERENCE_MODE" not in os.environ:
return None
stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
if stack_config_type != "server":
return None
try:
provider_data = PROVIDER_DATA_VAR.get()
except LookupError:
provider_data = None
if provider_data and "__test_id" in provider_data:
return TEST_CONTEXT.set(provider_data["__test_id"])
return None
def is_debug_mode() -> bool:
"""Check if test recording debug mode is enabled via LLAMA_STACK_TEST_DEBUG env var."""
return os.environ.get("LLAMA_STACK_TEST_DEBUG", "").lower() in ("1", "true", "yes")

View file

@ -0,0 +1,11 @@
# More info on playground configuration can be found here:
# https://llama-stack.readthedocs.io/en/latest/playground
FROM python:3.12-slim
WORKDIR /app
COPY . /app/
RUN /usr/local/bin/python -m pip install --upgrade pip && \
/usr/local/bin/pip3 install -r requirements.txt
EXPOSE 8501
ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]

View file

@ -0,0 +1,50 @@
# (Experimental) LLama Stack UI
## Docker Setup
:warning: This is a work in progress.
## Developer Setup
1. Start up Llama Stack API server. More details [here](https://llamastack.github.io/latest/getting_started/index.htmll).
```
llama stack list-deps together | xargs -L1 uv pip install
llama stack run together
```
2. (Optional) Register datasets and eval tasks as resources. If you want to run pre-configured evaluation flows (e.g. Evaluations (Generation + Scoring) Page).
```bash
llama-stack-client datasets register \
--dataset-id "mmlu" \
--provider-id "huggingface" \
--url "https://huggingface.co/datasets/llamastack/evals" \
--metadata '{"path": "llamastack/evals", "name": "evals__mmlu__details", "split": "train"}' \
--schema '{"input_query": {"type": "string"}, "expected_answer": {"type": "string", "chat_completion_input": {"type": "string"}}}'
```
```bash
llama-stack-client benchmarks register \
--eval-task-id meta-reference-mmlu \
--provider-id meta-reference \
--dataset-id mmlu \
--scoring-functions basic::regex_parser_multiple_choice_answer
```
3. Start Streamlit UI
```bash
uv run --with ".[ui]" streamlit run llama_stack.core/ui/app.py
```
## Environment Variables
| Environment Variable | Description | Default Value |
|----------------------------|------------------------------------|---------------------------|
| LLAMA_STACK_ENDPOINT | The endpoint for the Llama Stack | http://localhost:8321 |
| FIREWORKS_API_KEY | API key for Fireworks provider | (empty string) |
| TOGETHER_API_KEY | API key for Together provider | (empty string) |
| SAMBANOVA_API_KEY | API key for SambaNova provider | (empty string) |
| OPENAI_API_KEY | API key for OpenAI provider | (empty string) |

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
def main():
# Evaluation pages
application_evaluation_page = st.Page(
"page/evaluations/app_eval.py",
title="Evaluations (Scoring)",
icon="📊",
default=False,
)
native_evaluation_page = st.Page(
"page/evaluations/native_eval.py",
title="Evaluations (Generation + Scoring)",
icon="📊",
default=False,
)
# Playground pages
chat_page = st.Page("page/playground/chat.py", title="Chat", icon="💬", default=True)
rag_page = st.Page("page/playground/rag.py", title="RAG", icon="💬", default=False)
tool_page = st.Page("page/playground/tools.py", title="Tools", icon="🛠", default=False)
# Distribution pages
resources_page = st.Page("page/distribution/resources.py", title="Resources", icon="🔍", default=False)
provider_page = st.Page(
"page/distribution/providers.py",
title="API Providers",
icon="🔍",
default=False,
)
pg = st.navigation(
{
"Playground": [
chat_page,
rag_page,
tool_page,
application_evaluation_page,
native_evaluation_page,
],
"Inspect": [provider_page, resources_page],
},
expanded=False,
)
pg.run()
if __name__ == "__main__":
main()

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from llama_stack_client import LlamaStackClient
class LlamaStackApi:
def __init__(self):
self.client = LlamaStackClient(
base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:8321"),
provider_data={
"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY", ""),
"together_api_key": os.environ.get("TOGETHER_API_KEY", ""),
"sambanova_api_key": os.environ.get("SAMBANOVA_API_KEY", ""),
"openai_api_key": os.environ.get("OPENAI_API_KEY", ""),
"tavily_search_api_key": os.environ.get("TAVILY_SEARCH_API_KEY", ""),
},
)
def run_scoring(self, row, scoring_function_ids: list[str], scoring_params: dict | None):
"""Run scoring on a single row"""
if not scoring_params:
scoring_params = dict.fromkeys(scoring_function_ids)
return self.client.scoring.score(input_rows=[row], scoring_functions=scoring_params)
llama_stack_api = LlamaStackApi()

View file

@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import os
import pandas as pd
import streamlit as st
def process_dataset(file):
if file is None:
return "No file uploaded", None
try:
# Determine file type and read accordingly
file_ext = os.path.splitext(file.name)[1].lower()
if file_ext == ".csv":
df = pd.read_csv(file)
elif file_ext in [".xlsx", ".xls"]:
df = pd.read_excel(file)
else:
return "Unsupported file format. Please upload a CSV or Excel file.", None
return df
except Exception as e:
st.error(f"Error processing file: {str(e)}")
return None
def data_url_from_file(file) -> str:
file_content = file.getvalue()
base64_content = base64.b64encode(file_content).decode("utf-8")
mime_type = file.type
data_url = f"data:{mime_type};base64,{base64_content}"
return data_url

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from llama_stack.core.ui.modules.api import llama_stack_api
def datasets():
st.header("Datasets")
datasets_info = {d.identifier: d.to_dict() for d in llama_stack_api.client.datasets.list()}
if len(datasets_info) > 0:
selected_dataset = st.selectbox("Select a dataset", list(datasets_info.keys()))
st.json(datasets_info[selected_dataset], expanded=True)

View file

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from llama_stack.core.ui.modules.api import llama_stack_api
def benchmarks():
# Benchmarks Section
st.header("Benchmarks")
benchmarks_info = {d.identifier: d.to_dict() for d in llama_stack_api.client.benchmarks.list()}
if len(benchmarks_info) > 0:
selected_benchmark = st.selectbox("Select an eval task", list(benchmarks_info.keys()), key="benchmark_inspect")
st.json(benchmarks_info[selected_benchmark], expanded=True)

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from llama_stack.core.ui.modules.api import llama_stack_api
def models():
# Models Section
st.header("Models")
models_info = {m.identifier: m.to_dict() for m in llama_stack_api.client.models.list()}
selected_model = st.selectbox("Select a model", list(models_info.keys()))
st.json(models_info[selected_model])

View file

@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from llama_stack.core.ui.modules.api import llama_stack_api
def providers():
st.header("🔍 API Providers")
apis_providers_lst = llama_stack_api.client.providers.list()
api_to_providers = {}
for api_provider in apis_providers_lst:
if api_provider.api in api_to_providers:
api_to_providers[api_provider.api].append(api_provider)
else:
api_to_providers[api_provider.api] = [api_provider]
for api in api_to_providers.keys():
st.markdown(f"###### {api}")
st.dataframe([x.to_dict() for x in api_to_providers[api]], width=500)
providers()

View file

@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from streamlit_option_menu import option_menu
from llama_stack.core.ui.page.distribution.datasets import datasets
from llama_stack.core.ui.page.distribution.eval_tasks import benchmarks
from llama_stack.core.ui.page.distribution.models import models
from llama_stack.core.ui.page.distribution.scoring_functions import scoring_functions
from llama_stack.core.ui.page.distribution.shields import shields
def resources_page():
options = [
"Models",
"Shields",
"Scoring Functions",
"Datasets",
"Benchmarks",
]
icons = ["magic", "shield", "file-bar-graph", "database", "list-task"]
selected_resource = option_menu(
None,
options,
icons=icons,
orientation="horizontal",
styles={
"nav-link": {
"font-size": "12px",
},
},
)
if selected_resource == "Benchmarks":
benchmarks()
elif selected_resource == "Datasets":
datasets()
elif selected_resource == "Models":
models()
elif selected_resource == "Scoring Functions":
scoring_functions()
elif selected_resource == "Shields":
shields()
resources_page()

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from llama_stack.core.ui.modules.api import llama_stack_api
def scoring_functions():
st.header("Scoring Functions")
scoring_functions_info = {s.identifier: s.to_dict() for s in llama_stack_api.client.scoring_functions.list()}
selected_scoring_function = st.selectbox("Select a scoring function", list(scoring_functions_info.keys()))
st.json(scoring_functions_info[selected_scoring_function], expanded=True)

View file

@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from llama_stack.core.ui.modules.api import llama_stack_api
def shields():
# Shields Section
st.header("Shields")
shields_info = {s.identifier: s.to_dict() for s in llama_stack_api.client.shields.list()}
selected_shield = st.selectbox("Select a shield", list(shields_info.keys()))
st.json(shields_info[selected_shield])

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,143 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import pandas as pd
import streamlit as st
from llama_stack.core.ui.modules.api import llama_stack_api
from llama_stack.core.ui.modules.utils import process_dataset
def application_evaluation_page():
st.set_page_config(page_title="Evaluations (Scoring)", page_icon="🦙")
st.title("📊 Evaluations (Scoring)")
# File uploader
uploaded_file = st.file_uploader("Upload Dataset", type=["csv", "xlsx", "xls"])
if uploaded_file is None:
st.error("No file uploaded")
return
# Process uploaded file
df = process_dataset(uploaded_file)
if df is None:
st.error("Error processing file")
return
# Display dataset information
st.success("Dataset loaded successfully!")
# Display dataframe preview
st.subheader("Dataset Preview")
st.dataframe(df)
# Select Scoring Functions to Run Evaluation On
st.subheader("Select Scoring Functions")
scoring_functions = llama_stack_api.client.scoring_functions.list()
scoring_functions = {sf.identifier: sf for sf in scoring_functions}
scoring_functions_names = list(scoring_functions.keys())
selected_scoring_functions = st.multiselect(
"Choose one or more scoring functions",
options=scoring_functions_names,
help="Choose one or more scoring functions.",
)
available_models = llama_stack_api.client.models.list()
available_models = [m.identifier for m in available_models]
scoring_params = {}
if selected_scoring_functions:
st.write("Selected:")
for scoring_fn_id in selected_scoring_functions:
scoring_fn = scoring_functions[scoring_fn_id]
st.write(f"- **{scoring_fn_id}**: {scoring_fn.description}")
new_params = None
if scoring_fn.params:
new_params = {}
for param_name, param_value in scoring_fn.params.to_dict().items():
if param_name == "type":
new_params[param_name] = param_value
continue
if param_name == "judge_model":
value = st.selectbox(
f"Select **{param_name}** for {scoring_fn_id}",
options=available_models,
index=0,
key=f"{scoring_fn_id}_{param_name}",
)
new_params[param_name] = value
else:
value = st.text_area(
f"Enter value for **{param_name}** in {scoring_fn_id} in valid JSON format",
value=json.dumps(param_value, indent=2),
height=80,
)
try:
new_params[param_name] = json.loads(value)
except json.JSONDecodeError:
st.error(f"Invalid JSON for **{param_name}** in {scoring_fn_id}")
st.json(new_params)
scoring_params[scoring_fn_id] = new_params
# Add run evaluation button & slider
total_rows = len(df)
num_rows = st.slider("Number of rows to evaluate", 1, total_rows, total_rows)
if st.button("Run Evaluation"):
progress_text = "Running evaluation..."
progress_bar = st.progress(0, text=progress_text)
rows = df.to_dict(orient="records")
if num_rows < total_rows:
rows = rows[:num_rows]
# Create separate containers for progress text and results
progress_text_container = st.empty()
results_container = st.empty()
output_res = {}
for i, r in enumerate(rows):
# Update progress
progress = i / len(rows)
progress_bar.progress(progress, text=progress_text)
# Run evaluation for current row
score_res = llama_stack_api.run_scoring(
r,
scoring_function_ids=selected_scoring_functions,
scoring_params=scoring_params,
)
for k in r.keys():
if k not in output_res:
output_res[k] = []
output_res[k].append(r[k])
for fn_id in selected_scoring_functions:
if fn_id not in output_res:
output_res[fn_id] = []
output_res[fn_id].append(score_res.results[fn_id].score_rows[0])
# Display current row results using separate containers
progress_text_container.write(f"Expand to see current processed result ({i + 1} / {len(rows)})")
results_container.json(
score_res.to_json(),
expanded=2,
)
progress_bar.progress(1.0, text="Evaluation complete!")
# Display results in dataframe
if output_res:
output_df = pd.DataFrame(output_res)
st.subheader("Evaluation Results")
st.dataframe(output_df)
application_evaluation_page()

View file

@ -0,0 +1,253 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import pandas as pd
import streamlit as st
from llama_stack.core.ui.modules.api import llama_stack_api
def select_benchmark_1():
# Select Benchmarks
st.subheader("1. Choose An Eval Task")
benchmarks = llama_stack_api.client.benchmarks.list()
benchmarks = {et.identifier: et for et in benchmarks}
benchmarks_names = list(benchmarks.keys())
selected_benchmark = st.selectbox(
"Choose an eval task.",
options=benchmarks_names,
help="Choose an eval task. Each eval task is parameterized by a dataset, and list of scoring functions.",
)
with st.expander("View Eval Task"):
st.json(benchmarks[selected_benchmark], expanded=True)
st.session_state["selected_benchmark"] = selected_benchmark
st.session_state["benchmarks"] = benchmarks
if st.button("Confirm", key="confirm_1"):
st.session_state["selected_benchmark_1_next"] = True
def define_eval_candidate_2():
if not st.session_state.get("selected_benchmark_1_next", None):
return
st.subheader("2. Define Eval Candidate")
st.info(
"""
Define the configurations for the evaluation candidate model or agent used for generation.
Select "model" if you want to run generation with inference API, or "agent" if you want to run generation with agent API through specifying AgentConfig.
"""
)
with st.expander("Define Eval Candidate", expanded=True):
# Define Eval Candidate
candidate_type = st.radio("Candidate Type", ["model", "agent"])
available_models = llama_stack_api.client.models.list()
available_models = [model.identifier for model in available_models]
selected_model = st.selectbox(
"Choose a model",
available_models,
index=0,
)
# Sampling Parameters
st.markdown("##### Sampling Parameters")
temperature = st.slider(
"Temperature",
min_value=0.0,
max_value=1.0,
value=0.0,
step=0.1,
help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable",
)
top_p = st.slider(
"Top P",
min_value=0.0,
max_value=1.0,
value=0.95,
step=0.1,
)
max_tokens = st.slider(
"Max Tokens",
min_value=0,
max_value=4096,
value=512,
step=1,
help="The maximum number of tokens to generate",
)
repetition_penalty = st.slider(
"Repetition Penalty",
min_value=1.0,
max_value=2.0,
value=1.0,
step=0.1,
help="Controls the likelihood for generating the same word or phrase multiple times in the same sentence or paragraph. 1 implies no penalty, 2 will strongly discourage model to repeat words or phrases.",
)
if candidate_type == "model":
if temperature > 0.0:
strategy = {
"type": "top_p",
"temperature": temperature,
"top_p": top_p,
}
else:
strategy = {"type": "greedy"}
eval_candidate = {
"type": "model",
"model": selected_model,
"sampling_params": {
"strategy": strategy,
"max_tokens": max_tokens,
"repetition_penalty": repetition_penalty,
},
}
elif candidate_type == "agent":
system_prompt = st.text_area(
"System Prompt",
value="You are a helpful AI assistant.",
help="Initial instructions given to the AI to set its behavior and context",
)
tools_json = st.text_area(
"Tools Configuration (JSON)",
value=json.dumps(
[
{
"type": "brave_search",
"engine": "brave",
"api_key": "ENTER_BRAVE_API_KEY_HERE",
}
]
),
help="Enter tool configurations in JSON format. Each tool should have a name, description, and parameters.",
height=200,
)
try:
tools = json.loads(tools_json)
except json.JSONDecodeError:
st.error("Invalid JSON format for tools configuration")
tools = []
eval_candidate = {
"type": "agent",
"config": {
"model": selected_model,
"instructions": system_prompt,
"tools": tools,
"tool_choice": "auto",
"tool_prompt_format": "json",
"input_shields": [],
"output_shields": [],
"enable_session_persistence": False,
},
}
st.session_state["eval_candidate"] = eval_candidate
if st.button("Confirm", key="confirm_2"):
st.session_state["selected_eval_candidate_2_next"] = True
def run_evaluation_3():
if not st.session_state.get("selected_eval_candidate_2_next", None):
return
st.subheader("3. Run Evaluation")
# Add info box to explain configurations being used
st.info(
"""
Review the configurations that will be used for this evaluation run, make any necessary changes, and then click the "Run Evaluation" button.
"""
)
selected_benchmark = st.session_state["selected_benchmark"]
benchmarks = st.session_state["benchmarks"]
eval_candidate = st.session_state["eval_candidate"]
dataset_id = benchmarks[selected_benchmark].dataset_id
rows = llama_stack_api.client.datasets.iterrows(
dataset_id=dataset_id,
)
total_rows = len(rows.data)
# Add number of examples control
num_rows = st.number_input(
"Number of Examples to Evaluate",
min_value=1,
max_value=total_rows,
value=5,
help="Number of examples from the dataset to evaluate. ",
)
benchmark_config = {
"type": "benchmark",
"eval_candidate": eval_candidate,
"scoring_params": {},
}
with st.expander("View Evaluation Task", expanded=True):
st.json(benchmarks[selected_benchmark], expanded=True)
with st.expander("View Evaluation Task Configuration", expanded=True):
st.json(benchmark_config, expanded=True)
# Add run button and handle evaluation
if st.button("Run Evaluation"):
progress_text = "Running evaluation..."
progress_bar = st.progress(0, text=progress_text)
rows = rows.data
if num_rows < total_rows:
rows = rows[:num_rows]
# Create separate containers for progress text and results
progress_text_container = st.empty()
results_container = st.empty()
output_res = {}
for i, r in enumerate(rows):
# Update progress
progress = i / len(rows)
progress_bar.progress(progress, text=progress_text)
# Run evaluation for current row
eval_res = llama_stack_api.client.eval.evaluate_rows(
benchmark_id=selected_benchmark,
input_rows=[r],
scoring_functions=benchmarks[selected_benchmark].scoring_functions,
benchmark_config=benchmark_config,
)
for k in r.keys():
if k not in output_res:
output_res[k] = []
output_res[k].append(r[k])
for k in eval_res.generations[0].keys():
if k not in output_res:
output_res[k] = []
output_res[k].append(eval_res.generations[0][k])
for scoring_fn in benchmarks[selected_benchmark].scoring_functions:
if scoring_fn not in output_res:
output_res[scoring_fn] = []
output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0])
progress_text_container.write(f"Expand to see current processed result ({i + 1} / {len(rows)})")
results_container.json(eval_res, expanded=2)
progress_bar.progress(1.0, text="Evaluation complete!")
# Display results in dataframe
if output_res:
output_df = pd.DataFrame(output_res)
st.subheader("Evaluation Results")
st.dataframe(output_df)
def native_evaluation_page():
st.set_page_config(page_title="Evaluations (Generation + Scoring)", page_icon="🦙")
st.title("📊 Evaluations (Generation + Scoring)")
select_benchmark_1()
define_eval_candidate_2()
run_evaluation_3()
native_evaluation_page()

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,130 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from llama_stack.core.ui.modules.api import llama_stack_api
# Sidebar configurations
with st.sidebar:
st.header("Configuration")
available_models = llama_stack_api.client.models.list()
available_models = [model.identifier for model in available_models if model.model_type == "llm"]
selected_model = st.selectbox(
"Choose a model",
available_models,
index=0,
)
temperature = st.slider(
"Temperature",
min_value=0.0,
max_value=1.0,
value=0.0,
step=0.1,
help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable",
)
top_p = st.slider(
"Top P",
min_value=0.0,
max_value=1.0,
value=0.95,
step=0.1,
)
max_tokens = st.slider(
"Max Tokens",
min_value=0,
max_value=4096,
value=512,
step=1,
help="The maximum number of tokens to generate",
)
repetition_penalty = st.slider(
"Repetition Penalty",
min_value=1.0,
max_value=2.0,
value=1.0,
step=0.1,
help="Controls the likelihood for generating the same word or phrase multiple times in the same sentence or paragraph. 1 implies no penalty, 2 will strongly discourage model to repeat words or phrases.",
)
stream = st.checkbox("Stream", value=True)
system_prompt = st.text_area(
"System Prompt",
value="You are a helpful AI assistant.",
help="Initial instructions given to the AI to set its behavior and context",
)
# Add clear chat button to sidebar
if st.button("Clear Chat", use_container_width=True):
st.session_state.messages = []
st.rerun()
# Main chat interface
st.title("🦙 Chat")
# Initialize chat history
if "messages" not in st.session_state:
st.session_state.messages = []
# Display chat messages
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
# Chat input
if prompt := st.chat_input("Example: What is Llama Stack?"):
# Add user message to chat history
st.session_state.messages.append({"role": "user", "content": prompt})
# Display user message
with st.chat_message("user"):
st.markdown(prompt)
# Display assistant response
with st.chat_message("assistant"):
message_placeholder = st.empty()
full_response = ""
if temperature > 0.0:
strategy = {
"type": "top_p",
"temperature": temperature,
"top_p": top_p,
}
else:
strategy = {"type": "greedy"}
response = llama_stack_api.client.inference.chat_completion(
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt},
],
model_id=selected_model,
stream=stream,
sampling_params={
"strategy": strategy,
"max_tokens": max_tokens,
"repetition_penalty": repetition_penalty,
},
)
if stream:
for chunk in response:
if chunk.event.event_type == "progress":
full_response += chunk.event.delta.text
message_placeholder.markdown(full_response + "")
message_placeholder.markdown(full_response)
else:
full_response = response.completion_message.content
message_placeholder.markdown(full_response)
st.session_state.messages.append({"role": "assistant", "content": full_response})

View file

@ -0,0 +1,352 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import enum
import json
import uuid
import streamlit as st
from llama_stack_client import Agent
from llama_stack_client.lib.agents.react.agent import ReActAgent
from llama_stack_client.lib.agents.react.tool_parser import ReActOutput
from llama_stack.core.ui.modules.api import llama_stack_api
class AgentType(enum.Enum):
REGULAR = "Regular"
REACT = "ReAct"
def tool_chat_page():
st.title("🛠 Tools")
client = llama_stack_api.client
models = client.models.list()
model_list = [model.identifier for model in models if model.api_model_type == "llm"]
tool_groups = client.toolgroups.list()
tool_groups_list = [tool_group.identifier for tool_group in tool_groups]
mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")]
builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")]
selected_vector_stores = []
def reset_agent():
st.session_state.clear()
st.cache_resource.clear()
with st.sidebar:
st.title("Configuration")
st.subheader("Model")
model = st.selectbox(label="Model", options=model_list, on_change=reset_agent, label_visibility="collapsed")
st.subheader("Available ToolGroups")
toolgroup_selection = st.pills(
label="Built-in tools",
options=builtin_tools_list,
selection_mode="multi",
on_change=reset_agent,
format_func=lambda tool: "".join(tool.split("::")[1:]),
help="List of built-in tools from your llama stack server.",
)
if "builtin::rag" in toolgroup_selection:
vector_stores = llama_stack_api.client.vector_stores.list() or []
if not vector_stores:
st.info("No vector databases available for selection.")
vector_stores = [vector_store.identifier for vector_store in vector_stores]
selected_vector_stores = st.multiselect(
label="Select Document Collections to use in RAG queries",
options=vector_stores,
on_change=reset_agent,
)
mcp_selection = st.pills(
label="MCP Servers",
options=mcp_tools_list,
selection_mode="multi",
on_change=reset_agent,
format_func=lambda tool: "".join(tool.split("::")[1:]),
help="List of MCP servers registered to your llama stack server.",
)
toolgroup_selection.extend(mcp_selection)
grouped_tools = {}
total_tools = 0
for toolgroup_id in toolgroup_selection:
tools = client.tools.list(toolgroup_id=toolgroup_id)
grouped_tools[toolgroup_id] = [tool.name for tool in tools]
total_tools += len(tools)
st.markdown(f"Active Tools: 🛠 {total_tools}")
for group_id, tools in grouped_tools.items():
with st.expander(f"🔧 Tools from `{group_id}`"):
for idx, tool in enumerate(tools, start=1):
st.markdown(f"{idx}. `{tool.split(':')[-1]}`")
st.subheader("Agent Configurations")
st.subheader("Agent Type")
agent_type = st.radio(
label="Select Agent Type",
options=["Regular", "ReAct"],
on_change=reset_agent,
)
if agent_type == "ReAct":
agent_type = AgentType.REACT
else:
agent_type = AgentType.REGULAR
max_tokens = st.slider(
"Max Tokens",
min_value=0,
max_value=4096,
value=512,
step=64,
help="The maximum number of tokens to generate",
on_change=reset_agent,
)
for i, tool_name in enumerate(toolgroup_selection):
if tool_name == "builtin::rag":
tool_dict = dict(
name="builtin::rag",
args={
"vector_store_ids": list(selected_vector_stores),
},
)
toolgroup_selection[i] = tool_dict
@st.cache_resource
def create_agent():
if "agent_type" in st.session_state and st.session_state.agent_type == AgentType.REACT:
return ReActAgent(
client=client,
model=model,
tools=toolgroup_selection,
response_format={
"type": "json_schema",
"json_schema": ReActOutput.model_json_schema(),
},
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
)
else:
return Agent(
client,
model=model,
instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
tools=toolgroup_selection,
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
)
st.session_state.agent_type = agent_type
agent = create_agent()
if "agent_session_id" not in st.session_state:
st.session_state["agent_session_id"] = agent.create_session(session_name=f"tool_demo_{uuid.uuid4()}")
session_id = st.session_state["agent_session_id"]
if "messages" not in st.session_state:
st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
for msg in st.session_state.messages:
with st.chat_message(msg["role"]):
st.markdown(msg["content"])
if prompt := st.chat_input(placeholder=""):
with st.chat_message("user"):
st.markdown(prompt)
st.session_state.messages.append({"role": "user", "content": prompt})
turn_response = agent.create_turn(
session_id=session_id,
messages=[{"role": "user", "content": prompt}],
stream=True,
)
def response_generator(turn_response):
if st.session_state.get("agent_type") == AgentType.REACT:
return _handle_react_response(turn_response)
else:
return _handle_regular_response(turn_response)
def _handle_react_response(turn_response):
current_step_content = ""
final_answer = None
tool_results = []
for response in turn_response:
if not hasattr(response.event, "payload"):
yield (
"\n\n🚨 :red[_Llama Stack server Error:_]\n"
"The response received is missing an expected `payload` attribute.\n"
"This could indicate a malformed response or an internal issue within the server.\n\n"
f"Error details: {response}"
)
return
payload = response.event.payload
if payload.event_type == "step_progress" and hasattr(payload.delta, "text"):
current_step_content += payload.delta.text
continue
if payload.event_type == "step_complete":
step_details = payload.step_details
if step_details.step_type == "inference":
yield from _process_inference_step(current_step_content, tool_results, final_answer)
current_step_content = ""
elif step_details.step_type == "tool_execution":
tool_results = _process_tool_execution(step_details, tool_results)
current_step_content = ""
else:
current_step_content = ""
if not final_answer and tool_results:
yield from _format_tool_results_summary(tool_results)
def _process_inference_step(current_step_content, tool_results, final_answer):
try:
react_output_data = json.loads(current_step_content)
thought = react_output_data.get("thought")
action = react_output_data.get("action")
answer = react_output_data.get("answer")
if answer and answer != "null" and answer is not None:
final_answer = answer
if thought:
with st.expander("🤔 Thinking...", expanded=False):
st.markdown(f":grey[__{thought}__]")
if action and isinstance(action, dict):
tool_name = action.get("tool_name")
tool_params = action.get("tool_params")
with st.expander(f'🛠 Action: Using tool "{tool_name}"', expanded=False):
st.json(tool_params)
if answer and answer != "null" and answer is not None:
yield f"\n\n✅ **Final Answer:**\n{answer}"
except json.JSONDecodeError:
yield f"\n\nFailed to parse ReAct step content:\n```json\n{current_step_content}\n```"
except Exception as e:
yield f"\n\nFailed to process ReAct step: {e}\n```json\n{current_step_content}\n```"
return final_answer
def _process_tool_execution(step_details, tool_results):
try:
if hasattr(step_details, "tool_responses") and step_details.tool_responses:
for tool_response in step_details.tool_responses:
tool_name = tool_response.tool_name
content = tool_response.content
tool_results.append((tool_name, content))
with st.expander(f'⚙️ Observation (Result from "{tool_name}")', expanded=False):
try:
parsed_content = json.loads(content)
st.json(parsed_content)
except json.JSONDecodeError:
st.code(content, language=None)
else:
with st.expander("⚙️ Observation", expanded=False):
st.markdown(":grey[_Tool execution step completed, but no response data found._]")
except Exception as e:
with st.expander("⚙️ Error in Tool Execution", expanded=False):
st.markdown(f":red[_Error processing tool execution: {str(e)}_]")
return tool_results
def _format_tool_results_summary(tool_results):
yield "\n\n**Here's what I found:**\n"
for tool_name, content in tool_results:
try:
parsed_content = json.loads(content)
if tool_name == "web_search" and "top_k" in parsed_content:
yield from _format_web_search_results(parsed_content)
elif "results" in parsed_content and isinstance(parsed_content["results"], list):
yield from _format_results_list(parsed_content["results"])
elif isinstance(parsed_content, dict) and len(parsed_content) > 0:
yield from _format_dict_results(parsed_content)
elif isinstance(parsed_content, list) and len(parsed_content) > 0:
yield from _format_list_results(parsed_content)
except json.JSONDecodeError:
yield f"\n**{tool_name}** was used but returned complex data. Check the observation for details.\n"
except (TypeError, AttributeError, KeyError, IndexError) as e:
print(f"Error processing {tool_name} result: {type(e).__name__}: {e}")
def _format_web_search_results(parsed_content):
for i, result in enumerate(parsed_content["top_k"], 1):
if i <= 3:
title = result.get("title", "Untitled")
url = result.get("url", "")
content_text = result.get("content", "").strip()
yield f"\n- **{title}**\n {content_text}\n [Source]({url})\n"
def _format_results_list(results):
for i, result in enumerate(results, 1):
if i <= 3:
if isinstance(result, dict):
name = result.get("name", result.get("title", "Result " + str(i)))
description = result.get("description", result.get("content", result.get("summary", "")))
yield f"\n- **{name}**\n {description}\n"
else:
yield f"\n- {result}\n"
def _format_dict_results(parsed_content):
yield "\n```\n"
for key, value in list(parsed_content.items())[:5]:
if isinstance(value, str) and len(value) < 100:
yield f"{key}: {value}\n"
else:
yield f"{key}: [Complex data]\n"
yield "```\n"
def _format_list_results(parsed_content):
yield "\n"
for _, item in enumerate(parsed_content[:3], 1):
if isinstance(item, str):
yield f"- {item}\n"
elif isinstance(item, dict) and "text" in item:
yield f"- {item['text']}\n"
elif isinstance(item, dict) and len(item) > 0:
first_value = next(iter(item.values()))
if isinstance(first_value, str) and len(first_value) < 100:
yield f"- {first_value}\n"
def _handle_regular_response(turn_response):
for response in turn_response:
if hasattr(response.event, "payload"):
print(response.event.payload)
if response.event.payload.event_type == "step_progress":
if hasattr(response.event.payload.delta, "text"):
yield response.event.payload.delta.text
if response.event.payload.event_type == "step_complete":
if response.event.payload.step_details.step_type == "tool_execution":
if response.event.payload.step_details.tool_calls:
tool_name = str(response.event.payload.step_details.tool_calls[0].tool_name)
yield f'\n\n🛠 :grey[_Using "{tool_name}" tool:_]\n\n'
else:
yield "No tool_calls present in step_details"
else:
yield f"Error occurred in the Llama Stack Cluster: {response}"
with st.chat_message("assistant"):
response_content = st.write_stream(response_generator(turn_response))
st.session_state.messages.append({"role": "assistant", "content": response_content})
tool_chat_page()

View file

@ -0,0 +1,5 @@
llama-stack>=0.2.1
llama-stack-client>=0.2.1
pandas
streamlit
streamlit-option-menu

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,30 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
def redact_sensitive_fields(data: dict[str, Any]) -> dict[str, Any]:
"""Redact sensitive information from config before printing."""
sensitive_patterns = ["api_key", "api_token", "password", "secret"]
def _redact_value(v: Any) -> Any:
if isinstance(v, dict):
return _redact_dict(v)
elif isinstance(v, list):
return [_redact_value(i) for i in v]
return v
def _redact_dict(d: dict[str, Any]) -> dict[str, Any]:
result = {}
for k, v in d.items():
if any(pattern in k.lower() for pattern in sensitive_patterns):
result[k] = "********"
else:
result[k] = _redact_value(v)
return result
return _redact_dict(data)

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from pathlib import Path
LLAMA_STACK_CONFIG_DIR = Path(os.getenv("LLAMA_STACK_CONFIG_DIR", os.path.expanduser("~/.llama/")))
DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"
DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints"
RUNTIME_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "runtime"
EXTERNAL_PROVIDERS_DIR = LLAMA_STACK_CONFIG_DIR / "providers.d"

View file

@ -0,0 +1,125 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import StrEnum
from pathlib import Path
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="core")
DISTRO_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "distributions"
class Mode(StrEnum):
RUN = "run"
BUILD = "build"
def resolve_config_or_distro(
config_or_distro: str,
mode: Mode = Mode.RUN,
) -> Path:
"""
Resolve a config/distro argument to a concrete config file path.
Args:
config_or_distro: User input (file path, distribution name, or built distribution)
mode: Mode resolving for ("run", "build", "server")
Returns:
Path to the resolved config file
Raises:
ValueError: If resolution fails
"""
# Strategy 1: Try as file path first
config_path = Path(config_or_distro)
if config_path.exists() and config_path.is_file():
logger.debug(f"Using file path: {config_path}")
return config_path.resolve()
# Strategy 2: Try as distribution name (if no .yaml extension)
if not config_or_distro.endswith(".yaml"):
distro_config = _get_distro_config_path(config_or_distro, mode)
if distro_config.exists():
logger.debug(f"Using distribution: {distro_config}")
return distro_config
# Strategy 3: Try as built distribution name
distrib_config = DISTRIBS_BASE_DIR / f"llamastack-{config_or_distro}" / f"{config_or_distro}-{mode}.yaml"
if distrib_config.exists():
logger.debug(f"Using built distribution: {distrib_config}")
return distrib_config
distrib_config = DISTRIBS_BASE_DIR / f"{config_or_distro}" / f"{config_or_distro}-{mode}.yaml"
if distrib_config.exists():
logger.debug(f"Using built distribution: {distrib_config}")
return distrib_config
# Strategy 4: Failed - provide helpful error
raise ValueError(_format_resolution_error(config_or_distro, mode))
def _get_distro_config_path(distro_name: str, mode: Mode) -> Path:
"""Get the config file path for a distro."""
return DISTRO_DIR / distro_name / f"{mode}.yaml"
def _format_resolution_error(config_or_distro: str, mode: Mode) -> str:
"""Format a helpful error message for resolution failures."""
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
distro_path = _get_distro_config_path(config_or_distro, mode)
distrib_path = DISTRIBS_BASE_DIR / f"llamastack-{config_or_distro}" / f"{config_or_distro}-{mode}.yaml"
distrib_path2 = DISTRIBS_BASE_DIR / f"{config_or_distro}" / f"{config_or_distro}-{mode}.yaml"
available_distros = _get_available_distros()
distros_str = ", ".join(available_distros) if available_distros else "none found"
return f"""Could not resolve config or distribution '{config_or_distro}'.
Tried the following locations:
1. As file path: {Path(config_or_distro).resolve()}
2. As distribution: {distro_path}
3. As built distribution: ({distrib_path}, {distrib_path2})
Available distributions: {distros_str}
Did you mean one of these distributions?
{_format_distro_suggestions(available_distros, config_or_distro)}
"""
def _get_available_distros() -> list[str]:
"""Get list of available distro names."""
if not DISTRO_DIR.exists() and not DISTRIBS_BASE_DIR.exists():
return []
return list(
set(
[d.name for d in DISTRO_DIR.iterdir() if d.is_dir() and not d.name.startswith(".")]
+ [d.name for d in DISTRIBS_BASE_DIR.iterdir() if d.is_dir() and not d.name.startswith(".")]
)
)
def _format_distro_suggestions(distros: list[str], user_input: str) -> str:
"""Format distro suggestions for error messages, showing closest matches first."""
if not distros:
return " (no distros found)"
import difflib
# Get up to 3 closest matches with similarity threshold of 0.3 (lower = more permissive)
close_matches = difflib.get_close_matches(user_input, distros, n=3, cutoff=0.3)
display_distros = close_matches if close_matches else distros[:3]
suggestions = [f" - {d}" for d in display_distros]
return "\n".join(suggestions)

View file

@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from contextvars import ContextVar
def preserve_contexts_async_generator[T](
gen: AsyncGenerator[T, None], context_vars: list[ContextVar]
) -> AsyncGenerator[T, None]:
"""
Wraps an async generator to preserve context variables across iterations.
This is needed because we start a new asyncio event loop for each streaming request,
and we need to preserve the context across the event loop boundary.
"""
# Capture initial context values
initial_context_values = {context_var.name: context_var.get() for context_var in context_vars}
async def wrapper() -> AsyncGenerator[T, None]:
while True:
try:
# Restore context values before any await
for context_var in context_vars:
context_var.set(initial_context_values[context_var.name])
item = await gen.__anext__()
# Update our tracked values with any changes made during this iteration
for context_var in context_vars:
initial_context_values[context_var.name] = context_var.get()
yield item
except StopAsyncIteration:
break
return wrapper()

View file

@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import importlib
def instantiate_class_type(fully_qualified_name):
module_name, class_name = fully_qualified_name.rsplit(".", 1)
module = importlib.import_module(module_name)
return getattr(module, class_name)

View file

@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import importlib
import os
import signal
import subprocess
import sys
from termcolor import cprint
from llama_stack.log import get_logger
log = get_logger(name=__name__, category="core")
def formulate_run_args(image_type: str, image_name: str) -> list:
# Only venv is supported now
current_venv = os.environ.get("VIRTUAL_ENV")
env_name = image_name or current_venv
if not env_name:
cprint(
"No current virtual environment detected, please specify a virtual environment name with --image-name",
color="red",
file=sys.stderr,
)
return []
cprint(f"Using virtual environment: {env_name}", file=sys.stderr)
script = importlib.resources.files("llama_stack") / "core/start_stack.sh"
run_args = [
script,
image_type,
env_name,
]
return run_args
def in_notebook():
try:
from IPython import get_ipython
ipython = get_ipython()
if ipython is None or "IPKernelApp" not in ipython.config: # pragma: no cover
return False
except ImportError:
return False
except AttributeError:
return False
return True
def run_command(command: list[str]) -> int:
"""
Run a command with interrupt handling and output capture.
Uses subprocess.run with direct stream piping for better performance.
Args:
command (list): The command to run.
Returns:
int: The return code of the command.
"""
original_sigint = signal.getsignal(signal.SIGINT)
ctrl_c_pressed = False
def sigint_handler(signum, frame):
nonlocal ctrl_c_pressed
ctrl_c_pressed = True
log.info("\nCtrl-C detected. Aborting...")
try:
# Set up the signal handler
signal.signal(signal.SIGINT, sigint_handler)
# Run the command with stdout/stderr piped directly to system streams
result = subprocess.run(
command,
text=True,
check=False,
)
return result.returncode
except subprocess.SubprocessError as e:
log.error(f"Subprocess error: {e}")
return 1
except Exception as e:
log.exception(f"Unexpected error: {e}")
return 1
finally:
# Restore the original signal handler
signal.signal(signal.SIGINT, original_sigint)

View file

@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import enum
class LlamaStackImageType(enum.Enum):
CONTAINER = "container"
VENV = "venv"

View file

@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from .config_dirs import DEFAULT_CHECKPOINT_DIR
def model_local_dir(descriptor: str) -> str:
return str(Path(DEFAULT_CHECKPOINT_DIR) / (descriptor.replace(":", "-")))

View file

@ -0,0 +1,283 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import inspect
import json
from enum import Enum
from typing import Annotated, Any, Literal, Union, get_args, get_origin
from pydantic import BaseModel
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefinedType
from llama_stack.log import get_logger
log = get_logger(name=__name__, category="core")
def is_list_of_primitives(field_type):
"""Check if a field type is a List of primitive types."""
origin = get_origin(field_type)
if origin is list or origin is list:
args = get_args(field_type)
if len(args) == 1 and args[0] in (int, float, str, bool):
return True
return False
def is_basemodel_without_fields(typ):
return inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) == 0
def can_recurse(typ):
return inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) > 0
def get_literal_values(field):
"""Extract literal values from a field if it's a Literal type."""
if get_origin(field.annotation) is Literal:
return get_args(field.annotation)
return None
def is_optional(field_type):
"""Check if a field type is Optional."""
return get_origin(field_type) is Union and type(None) in get_args(field_type)
def get_non_none_type(field_type):
"""Get the non-None type from an Optional type."""
return next(arg for arg in get_args(field_type) if arg is not type(None))
def manually_validate_field(model: type[BaseModel], field_name: str, value: Any):
validators = model.__pydantic_decorators__.field_validators
for _name, validator in validators.items():
if field_name in validator.info.fields:
validator.func(value)
return value
def is_discriminated_union(typ) -> bool:
if isinstance(typ, FieldInfo):
return typ.discriminator
else:
if get_origin(typ) is not Annotated:
return False
args = get_args(typ)
return len(args) >= 2 and args[1].discriminator
def prompt_for_discriminated_union(
field_name,
typ,
existing_value,
):
if isinstance(typ, FieldInfo):
inner_type = typ.annotation
discriminator = typ.discriminator
default_value = typ.default
else:
args = get_args(typ)
inner_type = args[0]
discriminator = args[1].discriminator
default_value = args[1].default
union_types = get_args(inner_type)
# Find the discriminator field in each union type
type_map = {}
for t in union_types:
disc_field = t.__fields__[discriminator]
literal_values = get_literal_values(disc_field)
if literal_values:
for value in literal_values:
type_map[value] = t
while True:
prompt = f"Enter `{discriminator}` for {field_name} (options: {', '.join(type_map.keys())})"
if default_value is not None:
prompt += f" (default: {default_value})"
discriminator_value = input(f"{prompt}: ")
if discriminator_value == "" and default_value is not None:
discriminator_value = default_value
if discriminator_value in type_map:
chosen_type = type_map[discriminator_value]
log.info(f"\nConfiguring {chosen_type.__name__}:")
if existing_value and (getattr(existing_value, discriminator) != discriminator_value):
existing_value = None
sub_config = prompt_for_config(chosen_type, existing_value)
# Set the discriminator field in the sub-config
setattr(sub_config, discriminator, discriminator_value)
return sub_config
else:
log.error(f"Invalid {discriminator}. Please try again.")
# This is somewhat elaborate, but does not purport to be comprehensive in any way.
# We should add handling for the most common cases to tide us over.
#
# doesn't support List[nested_class] yet or Dicts of any kind. needs a bunch of
# unit tests for coverage.
def prompt_for_config(config_type: type[BaseModel], existing_config: BaseModel | None = None) -> BaseModel:
"""
Recursively prompt the user for configuration values based on a Pydantic BaseModel.
Args:
config_type: A Pydantic BaseModel class representing the configuration structure.
Returns:
An instance of the config_type with user-provided values.
"""
config_data = {}
for field_name, field in config_type.__fields__.items():
field_type = field.annotation
existing_value = getattr(existing_config, field_name) if existing_config else None
if existing_value:
default_value = existing_value
else:
default_value = field.default if not isinstance(field.default, PydanticUndefinedType) else None
is_required = field.is_required
# Skip fields with Literal type
if get_origin(field_type) is Literal:
continue
# Skip fields with no type annotations
if is_basemodel_without_fields(field_type):
config_data[field_name] = field_type()
continue
if inspect.isclass(field_type) and issubclass(field_type, Enum):
prompt = f"Choose {field_name} (options: {', '.join(e.name for e in field_type)}):"
while True:
# this branch does not handle existing and default values yet
user_input = input(prompt + " ")
try:
value = field_type[user_input]
validated_value = manually_validate_field(config_type, field, value)
config_data[field_name] = validated_value
break
except KeyError:
log.error(f"Invalid choice. Please choose from: {', '.join(e.name for e in field_type)}")
continue
if is_discriminated_union(field):
config_data[field_name] = prompt_for_discriminated_union(field_name, field, existing_value)
continue
if is_optional(field_type) and can_recurse(get_non_none_type(field_type)):
prompt = f"Do you want to configure {field_name}? (y/n): "
if input(prompt).lower() == "n":
config_data[field_name] = None
continue
nested_type = get_non_none_type(field_type)
log.info(f"Entering sub-configuration for {field_name}:")
config_data[field_name] = prompt_for_config(nested_type, existing_value)
elif is_optional(field_type) and is_discriminated_union(get_non_none_type(field_type)):
prompt = f"Do you want to configure {field_name}? (y/n): "
if input(prompt).lower() == "n":
config_data[field_name] = None
continue
nested_type = get_non_none_type(field_type)
config_data[field_name] = prompt_for_discriminated_union(
field_name,
nested_type,
existing_value,
)
elif can_recurse(field_type):
log.info(f"\nEntering sub-configuration for {field_name}:")
config_data[field_name] = prompt_for_config(
field_type,
existing_value,
)
else:
prompt = f"Enter value for {field_name}"
if existing_value is not None:
prompt += f" (existing: {existing_value})"
elif default_value is not None:
prompt += f" (default: {default_value})"
if is_optional(field_type):
prompt += " (optional)"
elif is_required:
prompt += " (required)"
prompt += ": "
while True:
user_input = input(prompt)
if user_input == "":
if default_value is not None:
config_data[field_name] = default_value
break
elif is_optional(field_type) or not is_required:
config_data[field_name] = None
break
else:
log.error("This field is required. Please provide a value.")
continue
else:
try:
# Handle Optional types
if is_optional(field_type):
if user_input.lower() == "none":
value = None
else:
field_type = get_non_none_type(field_type)
value = user_input
# Handle List of primitives
elif is_list_of_primitives(field_type):
try:
value = json.loads(user_input)
if not isinstance(value, list):
raise ValueError("Input must be a JSON-encoded list")
element_type = get_args(field_type)[0]
value = [element_type(item) for item in value]
except json.JSONDecodeError:
log.error('Invalid JSON. Please enter a valid JSON-encoded list e.g., ["foo","bar"]')
continue
except ValueError as e:
log.error(f"{str(e)}")
continue
elif get_origin(field_type) is dict:
try:
value = json.loads(user_input)
if not isinstance(value, dict):
raise ValueError("Input must be a JSON-encoded dictionary")
except json.JSONDecodeError:
log.error("Invalid JSON. Please enter a valid JSON-encoded dict.")
continue
# Convert the input to the correct type
elif inspect.isclass(field_type) and issubclass(field_type, BaseModel):
# For nested BaseModels, we assume a dictionary-like string input
import ast
value = field_type(**ast.literal_eval(user_input))
else:
value = field_type(user_input)
except ValueError:
log.error(f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}")
continue
try:
# Validate the field using our manual validation function
validated_value = manually_validate_field(config_type, field_name, value)
config_data[field_name] = validated_value
break
except ValueError as e:
log.error(f"Validation error: {str(e)}")
return config_type(**config_data)

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from datetime import datetime
from enum import Enum
class EnumEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, Enum):
return obj.value
elif isinstance(obj, datetime):
return obj.isoformat()
return super().default(obj)