chore: use starlette built-in Route class (#2267)

# What does this PR do?

Use a more common pattern and known terminology from the ecosystem,
where Route is more approved than Endpoint.

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-05-28 18:53:33 +02:00 committed by GitHub
parent 56e5ddb39f
commit 63a9f08c9e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 131 additions and 72 deletions

View file

@ -16,7 +16,7 @@ from llama_stack.apis.inspect import (
VersionInfo, VersionInfo,
) )
from llama_stack.distribution.datatypes import StackRunConfig from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.server.endpoints import get_all_api_endpoints from llama_stack.distribution.server.routes import get_all_api_routes
from llama_stack.providers.datatypes import HealthStatus from llama_stack.providers.datatypes import HealthStatus
@ -42,15 +42,15 @@ class DistributionInspectImpl(Inspect):
run_config: StackRunConfig = self.config.run_config run_config: StackRunConfig = self.config.run_config
ret = [] ret = []
all_endpoints = get_all_api_endpoints() all_endpoints = get_all_api_routes()
for api, endpoints in all_endpoints.items(): for api, endpoints in all_endpoints.items():
# Always include provider and inspect APIs, filter others based on run config # Always include provider and inspect APIs, filter others based on run config
if api.value in ["providers", "inspect"]: if api.value in ["providers", "inspect"]:
ret.extend( ret.extend(
[ [
RouteInfo( RouteInfo(
route=e.route, route=e.path,
method=e.method, method=next(iter([m for m in e.methods if m != "HEAD"])),
provider_types=[], # These APIs don't have "real" providers - they're internal to the stack provider_types=[], # These APIs don't have "real" providers - they're internal to the stack
) )
for e in endpoints for e in endpoints
@ -62,8 +62,8 @@ class DistributionInspectImpl(Inspect):
ret.extend( ret.extend(
[ [
RouteInfo( RouteInfo(
route=e.route, route=e.path,
method=e.method, method=next(iter([m for m in e.methods if m != "HEAD"])),
provider_types=[p.provider_type for p in providers], provider_types=[p.provider_type for p in providers],
) )
for e in endpoints for e in endpoints

View file

@ -37,10 +37,7 @@ from llama_stack.distribution.request_headers import (
request_provider_data_context, request_provider_data_context,
) )
from llama_stack.distribution.resolver import ProviderRegistry from llama_stack.distribution.resolver import ProviderRegistry
from llama_stack.distribution.server.endpoints import ( from llama_stack.distribution.server.routes import find_matching_route, initialize_route_impls
find_matching_endpoint,
initialize_endpoint_impls,
)
from llama_stack.distribution.stack import ( from llama_stack.distribution.stack import (
construct_stack, construct_stack,
get_stack_run_config_from_template, get_stack_run_config_from_template,
@ -208,7 +205,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
async def initialize(self) -> bool: async def initialize(self) -> bool:
try: try:
self.endpoint_impls = None self.route_impls = None
self.impls = await construct_stack(self.config, self.custom_provider_registry) self.impls = await construct_stack(self.config, self.custom_provider_registry)
except ModuleNotFoundError as _e: except ModuleNotFoundError as _e:
cprint(_e.msg, color="red", file=sys.stderr) cprint(_e.msg, color="red", file=sys.stderr)
@ -254,7 +251,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
safe_config = redact_sensitive_fields(self.config.model_dump()) safe_config = redact_sensitive_fields(self.config.model_dump())
console.print(yaml.dump(safe_config, indent=2)) console.print(yaml.dump(safe_config, indent=2))
self.endpoint_impls = initialize_endpoint_impls(self.impls) self.route_impls = initialize_route_impls(self.impls)
return True return True
async def request( async def request(
@ -265,7 +262,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
stream=False, stream=False,
stream_cls=None, stream_cls=None,
): ):
if not self.endpoint_impls: if not self.route_impls:
raise ValueError("Client not initialized") raise ValueError("Client not initialized")
# Create headers with provider data if available # Create headers with provider data if available
@ -296,11 +293,14 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
cast_to: Any, cast_to: Any,
options: Any, options: Any,
): ):
if self.route_impls is None:
raise ValueError("Client not initialized")
path = options.url path = options.url
body = options.params or {} body = options.params or {}
body |= options.json_data or {} body |= options.json_data or {}
matched_func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls) matched_func, path_params, route = find_matching_route(options.method, path, self.route_impls)
body |= path_params body |= path_params
body = self._convert_body(path, options.method, body) body = self._convert_body(path, options.method, body)
await start_trace(route, {"__location__": "library_client"}) await start_trace(route, {"__location__": "library_client"})
@ -342,10 +342,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
options: Any, options: Any,
stream_cls: Any, stream_cls: Any,
): ):
if self.route_impls is None:
raise ValueError("Client not initialized")
path = options.url path = options.url
body = options.params or {} body = options.params or {}
body |= options.json_data or {} body |= options.json_data or {}
func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls) func, path_params, route = find_matching_route(options.method, path, self.route_impls)
body |= path_params body |= path_params
body = self._convert_body(path, options.method, body) body = self._convert_body(path, options.method, body)
@ -397,7 +400,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
if not body: if not body:
return {} return {}
func, _, _ = find_matching_endpoint(method, path, self.endpoint_impls) if self.route_impls is None:
raise ValueError("Client not initialized")
func, _, _ = find_matching_route(method, path, self.route_impls)
sig = inspect.signature(func) sig = inspect.signature(func)
# Strip NOT_GIVENs to use the defaults in signature # Strip NOT_GIVENs to use the defaults in signature

View file

@ -6,20 +6,23 @@
import inspect import inspect
import re import re
from collections.abc import Callable
from typing import Any
from pydantic import BaseModel from aiohttp import hdrs
from starlette.routing import Route
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
from llama_stack.apis.version import LLAMA_STACK_API_VERSION from llama_stack.apis.version import LLAMA_STACK_API_VERSION
from llama_stack.distribution.resolver import api_protocol_map from llama_stack.distribution.resolver import api_protocol_map
from llama_stack.providers.datatypes import Api from llama_stack.providers.datatypes import Api
EndpointFunc = Callable[..., Any]
class ApiEndpoint(BaseModel): PathParams = dict[str, str]
route: str RouteInfo = tuple[EndpointFunc, str]
method: str PathImpl = dict[str, RouteInfo]
name: str RouteImpls = dict[str, PathImpl]
descriptive_name: str | None = None RouteMatch = tuple[EndpointFunc, PathParams, str]
def toolgroup_protocol_map(): def toolgroup_protocol_map():
@ -28,13 +31,13 @@ def toolgroup_protocol_map():
} }
def get_all_api_endpoints() -> dict[Api, list[ApiEndpoint]]: def get_all_api_routes() -> dict[Api, list[Route]]:
apis = {} apis = {}
protocols = api_protocol_map() protocols = api_protocol_map()
toolgroup_protocols = toolgroup_protocol_map() toolgroup_protocols = toolgroup_protocol_map()
for api, protocol in protocols.items(): for api, protocol in protocols.items():
endpoints = [] routes = []
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction) protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
# HACK ALERT # HACK ALERT
@ -51,26 +54,28 @@ def get_all_api_endpoints() -> dict[Api, list[ApiEndpoint]]:
if not hasattr(method, "__webmethod__"): if not hasattr(method, "__webmethod__"):
continue continue
webmethod = method.__webmethod__ # The __webmethod__ attribute is dynamically added by the @webmethod decorator
route = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}" # mypy doesn't know about this dynamic attribute, so we ignore the attr-defined error
if webmethod.method == "GET": webmethod = method.__webmethod__ # type: ignore[attr-defined]
method = "get" path = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
elif webmethod.method == "DELETE": if webmethod.method == hdrs.METH_GET:
method = "delete" http_method = hdrs.METH_GET
elif webmethod.method == hdrs.METH_DELETE:
http_method = hdrs.METH_DELETE
else: else:
method = "post" http_method = hdrs.METH_POST
endpoints.append( routes.append(
ApiEndpoint(route=route, method=method, name=name, descriptive_name=webmethod.descriptive_name) Route(path=path, methods=[http_method], name=name, endpoint=None)
) ) # setting endpoint to None since don't use a Router object
apis[api] = endpoints apis[api] = routes
return apis return apis
def initialize_endpoint_impls(impls): def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
endpoints = get_all_api_endpoints() routes = get_all_api_routes()
endpoint_impls = {} route_impls: RouteImpls = {}
def _convert_path_to_regex(path: str) -> str: def _convert_path_to_regex(path: str) -> str:
# Convert {param} to named capture groups # Convert {param} to named capture groups
@ -83,29 +88,34 @@ def initialize_endpoint_impls(impls):
return f"^{pattern}$" return f"^{pattern}$"
for api, api_endpoints in endpoints.items(): for api, api_routes in routes.items():
if api not in impls: if api not in impls:
continue continue
for endpoint in api_endpoints: for route in api_routes:
impl = impls[api] impl = impls[api]
func = getattr(impl, endpoint.name) func = getattr(impl, route.name)
if endpoint.method not in endpoint_impls: # Get the first (and typically only) method from the set, filtering out HEAD
endpoint_impls[endpoint.method] = {} available_methods = [m for m in route.methods if m != "HEAD"]
endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = ( if not available_methods:
continue # Skip if only HEAD method is available
method = available_methods[0].lower()
if method not in route_impls:
route_impls[method] = {}
route_impls[method][_convert_path_to_regex(route.path)] = (
func, func,
endpoint.descriptive_name or endpoint.route, route.path,
) )
return endpoint_impls return route_impls
def find_matching_endpoint(method, path, endpoint_impls): def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> RouteMatch:
"""Find the matching endpoint implementation for a given method and path. """Find the matching endpoint implementation for a given method and path.
Args: Args:
method: HTTP method (GET, POST, etc.) method: HTTP method (GET, POST, etc.)
path: URL path to match against path: URL path to match against
endpoint_impls: A dictionary of endpoint implementations route_impls: A dictionary of endpoint implementations
Returns: Returns:
A tuple of (endpoint_function, path_params, descriptive_name) A tuple of (endpoint_function, path_params, descriptive_name)
@ -113,7 +123,7 @@ def find_matching_endpoint(method, path, endpoint_impls):
Raises: Raises:
ValueError: If no matching endpoint is found ValueError: If no matching endpoint is found
""" """
impls = endpoint_impls.get(method.lower()) impls = route_impls.get(method.lower())
if not impls: if not impls:
raise ValueError(f"No endpoint found for {path}") raise ValueError(f"No endpoint found for {path}")

View file

@ -6,6 +6,7 @@
import argparse import argparse
import asyncio import asyncio
import functools
import inspect import inspect
import json import json
import os import os
@ -13,6 +14,7 @@ import ssl
import sys import sys
import traceback import traceback
import warnings import warnings
from collections.abc import Callable
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from importlib.metadata import version as parse_version from importlib.metadata import version as parse_version
from pathlib import Path from pathlib import Path
@ -20,6 +22,7 @@ from typing import Annotated, Any
import rich.pretty import rich.pretty
import yaml import yaml
from aiohttp import hdrs
from fastapi import Body, FastAPI, HTTPException, Request from fastapi import Body, FastAPI, HTTPException, Request
from fastapi import Path as FastapiPath from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
@ -35,9 +38,10 @@ from llama_stack.distribution.request_headers import (
request_provider_data_context, request_provider_data_context,
) )
from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.server.endpoints import ( from llama_stack.distribution.server.routes import (
find_matching_endpoint, find_matching_route,
initialize_endpoint_impls, get_all_api_routes,
initialize_route_impls,
) )
from llama_stack.distribution.stack import ( from llama_stack.distribution.stack import (
construct_stack, construct_stack,
@ -60,7 +64,6 @@ from llama_stack.providers.utils.telemetry.tracing import (
) )
from .auth import AuthenticationMiddleware from .auth import AuthenticationMiddleware
from .endpoints import get_all_api_endpoints
from .quota import QuotaMiddleware from .quota import QuotaMiddleware
REPO_ROOT = Path(__file__).parent.parent.parent.parent REPO_ROOT = Path(__file__).parent.parent.parent.parent
@ -209,8 +212,9 @@ async def log_request_pre_validation(request: Request):
logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}") logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")
def create_dynamic_typed_route(func: Any, method: str, route: str): def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
async def endpoint(request: Request, **kwargs): @functools.wraps(func)
async def route_handler(request: Request, **kwargs):
# Get auth attributes from the request scope # Get auth attributes from the request scope
user_attributes = request.scope.get("user_attributes", {}) user_attributes = request.scope.get("user_attributes", {})
@ -250,9 +254,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
for param in new_params[1:] for param in new_params[1:]
] ]
endpoint.__signature__ = sig.replace(parameters=new_params) route_handler.__signature__ = sig.replace(parameters=new_params)
return endpoint return route_handler
class TracingMiddleware: class TracingMiddleware:
@ -274,14 +278,14 @@ class TracingMiddleware:
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}") logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
return await self.app(scope, receive, send) return await self.app(scope, receive, send)
if not hasattr(self, "endpoint_impls"): if not hasattr(self, "route_impls"):
self.endpoint_impls = initialize_endpoint_impls(self.impls) self.route_impls = initialize_route_impls(self.impls)
try: try:
_, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls) _, _, trace_path = find_matching_route(scope.get("method", hdrs.METH_GET), path, self.route_impls)
except ValueError: except ValueError:
# If no matching endpoint is found, pass through to FastAPI # If no matching endpoint is found, pass through to FastAPI
logger.debug(f"No matching endpoint found for path: {path}, falling back to FastAPI") logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
return await self.app(scope, receive, send) return await self.app(scope, receive, send)
trace_attributes = {"__location__": "server", "raw_path": path} trace_attributes = {"__location__": "server", "raw_path": path}
@ -490,7 +494,7 @@ def main(args: argparse.Namespace | None = None):
else: else:
setup_logger(TelemetryAdapter(TelemetryConfig(), {})) setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
all_endpoints = get_all_api_endpoints() all_routes = get_all_api_routes()
if config.apis: if config.apis:
apis_to_serve = set(config.apis) apis_to_serve = set(config.apis)
@ -508,24 +512,29 @@ def main(args: argparse.Namespace | None = None):
for api_str in apis_to_serve: for api_str in apis_to_serve:
api = Api(api_str) api = Api(api_str)
endpoints = all_endpoints[api] routes = all_routes[api]
impl = impls[api] impl = impls[api]
for endpoint in endpoints: for route in routes:
if not hasattr(impl, endpoint.name): if not hasattr(impl, route.name):
# ideally this should be a typing violation already # ideally this should be a typing violation already
raise ValueError(f"Could not find method {endpoint.name} on {impl}!!") raise ValueError(f"Could not find method {route.name} on {impl}!")
impl_method = getattr(impl, endpoint.name) impl_method = getattr(impl, route.name)
logger.debug(f"{endpoint.method.upper()} {endpoint.route}") # Filter out HEAD method since it's automatically handled by FastAPI for GET routes
available_methods = [m for m in route.methods if m != "HEAD"]
if not available_methods:
raise ValueError(f"No methods found for {route.name} on {impl}")
method = available_methods[0]
logger.debug(f"{method} {route.path}")
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields") warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
getattr(app, endpoint.method)(endpoint.route, response_model=None)( getattr(app, method.lower())(route.path, response_model=None)(
create_dynamic_typed_route( create_dynamic_typed_route(
impl_method, impl_method,
endpoint.method, method.lower(),
endpoint.route, route.path,
) )
) )

View file

@ -21,6 +21,7 @@ classifiers = [
"Topic :: Scientific/Engineering :: Information Analysis", "Topic :: Scientific/Engineering :: Information Analysis",
] ]
dependencies = [ dependencies = [
"aiohttp",
"fire", "fire",
"httpx", "httpx",
"huggingface-hub", "huggingface-hub",
@ -35,6 +36,7 @@ dependencies = [
"requests", "requests",
"rich", "rich",
"setuptools", "setuptools",
"starlette",
"termcolor", "termcolor",
"tiktoken", "tiktoken",
"pillow", "pillow",

View file

@ -1,5 +1,11 @@
# This file was autogenerated by uv via the following command: # This file was autogenerated by uv via the following command:
# uv export --frozen --no-hashes --no-emit-project --no-default-groups --output-file=requirements.txt # uv export --frozen --no-hashes --no-emit-project --no-default-groups --output-file=requirements.txt
aiohappyeyeballs==2.5.0
# via aiohttp
aiohttp==3.11.13
# via llama-stack
aiosignal==1.3.2
# via aiohttp
annotated-types==0.7.0 annotated-types==0.7.0
# via pydantic # via pydantic
anyio==4.8.0 anyio==4.8.0
@ -7,8 +13,12 @@ anyio==4.8.0
# httpx # httpx
# llama-stack-client # llama-stack-client
# openai # openai
# starlette
async-timeout==5.0.1 ; python_full_version < '3.11'
# via aiohttp
attrs==25.1.0 attrs==25.1.0
# via # via
# aiohttp
# jsonschema # jsonschema
# referencing # referencing
certifi==2025.1.31 certifi==2025.1.31
@ -36,6 +46,10 @@ filelock==3.17.0
# via huggingface-hub # via huggingface-hub
fire==0.7.0 fire==0.7.0
# via llama-stack # via llama-stack
frozenlist==1.5.0
# via
# aiohttp
# aiosignal
fsspec==2024.12.0 fsspec==2024.12.0
# via huggingface-hub # via huggingface-hub
h11==0.16.0 h11==0.16.0
@ -56,6 +70,7 @@ idna==3.10
# anyio # anyio
# httpx # httpx
# requests # requests
# yarl
jinja2==3.1.6 jinja2==3.1.6
# via llama-stack # via llama-stack
jiter==0.8.2 jiter==0.8.2
@ -72,6 +87,10 @@ markupsafe==3.0.2
# via jinja2 # via jinja2
mdurl==0.1.2 mdurl==0.1.2
# via markdown-it-py # via markdown-it-py
multidict==6.1.0
# via
# aiohttp
# yarl
numpy==2.2.3 numpy==2.2.3
# via pandas # via pandas
openai==1.71.0 openai==1.71.0
@ -86,6 +105,10 @@ prompt-toolkit==3.0.50
# via # via
# llama-stack # llama-stack
# llama-stack-client # llama-stack-client
propcache==0.3.0
# via
# aiohttp
# yarl
pyaml==25.1.0 pyaml==25.1.0
# via llama-stack-client # via llama-stack-client
pyasn1==0.4.8 pyasn1==0.4.8
@ -145,6 +168,8 @@ sniffio==1.3.1
# anyio # anyio
# llama-stack-client # llama-stack-client
# openai # openai
starlette==0.45.3
# via llama-stack
termcolor==2.5.0 termcolor==2.5.0
# via # via
# fire # fire
@ -162,6 +187,7 @@ typing-extensions==4.12.2
# anyio # anyio
# huggingface-hub # huggingface-hub
# llama-stack-client # llama-stack-client
# multidict
# openai # openai
# pydantic # pydantic
# pydantic-core # pydantic-core
@ -173,3 +199,5 @@ urllib3==2.3.0
# via requests # via requests
wcwidth==0.2.13 wcwidth==0.2.13
# via prompt-toolkit # via prompt-toolkit
yarl==1.18.3
# via aiohttp

4
uv.lock generated
View file

@ -1456,6 +1456,7 @@ name = "llama-stack"
version = "0.2.8" version = "0.2.8"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "aiohttp" },
{ name = "fire" }, { name = "fire" },
{ name = "h11" }, { name = "h11" },
{ name = "httpx" }, { name = "httpx" },
@ -1472,6 +1473,7 @@ dependencies = [
{ name = "requests" }, { name = "requests" },
{ name = "rich" }, { name = "rich" },
{ name = "setuptools" }, { name = "setuptools" },
{ name = "starlette" },
{ name = "termcolor" }, { name = "termcolor" },
{ name = "tiktoken" }, { name = "tiktoken" },
] ]
@ -1557,6 +1559,7 @@ unit = [
[package.metadata] [package.metadata]
requires-dist = [ requires-dist = [
{ name = "aiohttp" },
{ name = "fire" }, { name = "fire" },
{ name = "h11", specifier = ">=0.16.0" }, { name = "h11", specifier = ">=0.16.0" },
{ name = "httpx" }, { name = "httpx" },
@ -1575,6 +1578,7 @@ requires-dist = [
{ name = "requests" }, { name = "requests" },
{ name = "rich" }, { name = "rich" },
{ name = "setuptools" }, { name = "setuptools" },
{ name = "starlette" },
{ name = "streamlit", marker = "extra == 'ui'" }, { name = "streamlit", marker = "extra == 'ui'" },
{ name = "streamlit-option-menu", marker = "extra == 'ui'" }, { name = "streamlit-option-menu", marker = "extra == 'ui'" },
{ name = "termcolor" }, { name = "termcolor" },