From 63a9f08c9ee3fb591d99a8975ade915038bf4c3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Wed, 28 May 2025 18:53:33 +0200 Subject: [PATCH] chore: use starlette built-in Route class (#2267) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? Use a more common pattern and known terminology from the ecosystem, where Route is more widely accepted than Endpoint. Signed-off-by: Sébastien Han --- llama_stack/distribution/inspect.py | 12 +-- llama_stack/distribution/library_client.py | 26 ++++--- .../server/{endpoints.py => routes.py} | 78 +++++++++++-------- llama_stack/distribution/server/server.py | 53 +++++++------ pyproject.toml | 2 + requirements.txt | 28 +++++++ uv.lock | 4 + 7 files changed, 131 insertions(+), 72 deletions(-) rename llama_stack/distribution/server/{endpoints.py => routes.py} (55%) diff --git a/llama_stack/distribution/inspect.py b/llama_stack/distribution/inspect.py index 3321ec291..5822070ad 100644 --- a/llama_stack/distribution/inspect.py +++ b/llama_stack/distribution/inspect.py @@ -16,7 +16,7 @@ from llama_stack.apis.inspect import ( VersionInfo, ) from llama_stack.distribution.datatypes import StackRunConfig -from llama_stack.distribution.server.endpoints import get_all_api_endpoints +from llama_stack.distribution.server.routes import get_all_api_routes from llama_stack.providers.datatypes import HealthStatus @@ -42,15 +42,15 @@ class DistributionInspectImpl(Inspect): run_config: StackRunConfig = self.config.run_config ret = [] - all_endpoints = get_all_api_endpoints() + all_endpoints = get_all_api_routes() for api, endpoints in all_endpoints.items(): # Always include provider and inspect APIs, filter others based on run config if api.value in ["providers", "inspect"]: ret.extend( [ RouteInfo( - route=e.route, - method=e.method, + route=e.path, + method=next(iter([m for m in e.methods if m != "HEAD"])), provider_types=[], # These APIs don't have "real" providers - 
they're internal to the stack ) for e in endpoints @@ -62,8 +62,8 @@ class DistributionInspectImpl(Inspect): ret.extend( [ RouteInfo( - route=e.route, - method=e.method, + route=e.path, + method=next(iter([m for m in e.methods if m != "HEAD"])), provider_types=[p.provider_type for p in providers], ) for e in endpoints diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index 3cd2d1728..f32130cf9 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -37,10 +37,7 @@ from llama_stack.distribution.request_headers import ( request_provider_data_context, ) from llama_stack.distribution.resolver import ProviderRegistry -from llama_stack.distribution.server.endpoints import ( - find_matching_endpoint, - initialize_endpoint_impls, -) +from llama_stack.distribution.server.routes import find_matching_route, initialize_route_impls from llama_stack.distribution.stack import ( construct_stack, get_stack_run_config_from_template, @@ -208,7 +205,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): async def initialize(self) -> bool: try: - self.endpoint_impls = None + self.route_impls = None self.impls = await construct_stack(self.config, self.custom_provider_registry) except ModuleNotFoundError as _e: cprint(_e.msg, color="red", file=sys.stderr) @@ -254,7 +251,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): safe_config = redact_sensitive_fields(self.config.model_dump()) console.print(yaml.dump(safe_config, indent=2)) - self.endpoint_impls = initialize_endpoint_impls(self.impls) + self.route_impls = initialize_route_impls(self.impls) return True async def request( @@ -265,7 +262,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): stream=False, stream_cls=None, ): - if not self.endpoint_impls: + if not self.route_impls: raise ValueError("Client not initialized") # Create headers with provider data if available @@ -296,11 +293,14 @@ 
class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): cast_to: Any, options: Any, ): + if self.route_impls is None: + raise ValueError("Client not initialized") + path = options.url body = options.params or {} body |= options.json_data or {} - matched_func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls) + matched_func, path_params, route = find_matching_route(options.method, path, self.route_impls) body |= path_params body = self._convert_body(path, options.method, body) await start_trace(route, {"__location__": "library_client"}) @@ -342,10 +342,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): options: Any, stream_cls: Any, ): + if self.route_impls is None: + raise ValueError("Client not initialized") + path = options.url body = options.params or {} body |= options.json_data or {} - func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls) + func, path_params, route = find_matching_route(options.method, path, self.route_impls) body |= path_params body = self._convert_body(path, options.method, body) @@ -397,7 +400,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): if not body: return {} - func, _, _ = find_matching_endpoint(method, path, self.endpoint_impls) + if self.route_impls is None: + raise ValueError("Client not initialized") + + func, _, _ = find_matching_route(method, path, self.route_impls) sig = inspect.signature(func) # Strip NOT_GIVENs to use the defaults in signature diff --git a/llama_stack/distribution/server/endpoints.py b/llama_stack/distribution/server/routes.py similarity index 55% rename from llama_stack/distribution/server/endpoints.py rename to llama_stack/distribution/server/routes.py index ec1f7e083..ea66fec5a 100644 --- a/llama_stack/distribution/server/endpoints.py +++ b/llama_stack/distribution/server/routes.py @@ -6,20 +6,23 @@ import inspect import re +from collections.abc import Callable +from typing import Any -from 
pydantic import BaseModel +from aiohttp import hdrs +from starlette.routing import Route from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup from llama_stack.apis.version import LLAMA_STACK_API_VERSION from llama_stack.distribution.resolver import api_protocol_map from llama_stack.providers.datatypes import Api - -class ApiEndpoint(BaseModel): - route: str - method: str - name: str - descriptive_name: str | None = None +EndpointFunc = Callable[..., Any] +PathParams = dict[str, str] +RouteInfo = tuple[EndpointFunc, str] +PathImpl = dict[str, RouteInfo] +RouteImpls = dict[str, PathImpl] +RouteMatch = tuple[EndpointFunc, PathParams, str] def toolgroup_protocol_map(): @@ -28,13 +31,13 @@ def toolgroup_protocol_map(): } -def get_all_api_endpoints() -> dict[Api, list[ApiEndpoint]]: +def get_all_api_routes() -> dict[Api, list[Route]]: apis = {} protocols = api_protocol_map() toolgroup_protocols = toolgroup_protocol_map() for api, protocol in protocols.items(): - endpoints = [] + routes = [] protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction) # HACK ALERT @@ -51,26 +54,28 @@ def get_all_api_endpoints() -> dict[Api, list[ApiEndpoint]]: if not hasattr(method, "__webmethod__"): continue - webmethod = method.__webmethod__ - route = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}" - if webmethod.method == "GET": - method = "get" - elif webmethod.method == "DELETE": - method = "delete" + # The __webmethod__ attribute is dynamically added by the @webmethod decorator + # mypy doesn't know about this dynamic attribute, so we ignore the attr-defined error + webmethod = method.__webmethod__ # type: ignore[attr-defined] + path = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}" + if webmethod.method == hdrs.METH_GET: + http_method = hdrs.METH_GET + elif webmethod.method == hdrs.METH_DELETE: + http_method = hdrs.METH_DELETE else: - method = "post" - endpoints.append( - ApiEndpoint(route=route, method=method, name=name, 
descriptive_name=webmethod.descriptive_name) - ) + http_method = hdrs.METH_POST + routes.append( + Route(path=path, methods=[http_method], name=name, endpoint=None) + ) # setting endpoint to None since don't use a Router object - apis[api] = endpoints + apis[api] = routes return apis -def initialize_endpoint_impls(impls): - endpoints = get_all_api_endpoints() - endpoint_impls = {} +def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls: + routes = get_all_api_routes() + route_impls: RouteImpls = {} def _convert_path_to_regex(path: str) -> str: # Convert {param} to named capture groups @@ -83,29 +88,34 @@ def initialize_endpoint_impls(impls): return f"^{pattern}$" - for api, api_endpoints in endpoints.items(): + for api, api_routes in routes.items(): if api not in impls: continue - for endpoint in api_endpoints: + for route in api_routes: impl = impls[api] - func = getattr(impl, endpoint.name) - if endpoint.method not in endpoint_impls: - endpoint_impls[endpoint.method] = {} - endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = ( + func = getattr(impl, route.name) + # Get the first (and typically only) method from the set, filtering out HEAD + available_methods = [m for m in route.methods if m != "HEAD"] + if not available_methods: + continue # Skip if only HEAD method is available + method = available_methods[0].lower() + if method not in route_impls: + route_impls[method] = {} + route_impls[method][_convert_path_to_regex(route.path)] = ( func, - endpoint.descriptive_name or endpoint.route, + route.path, ) - return endpoint_impls + return route_impls -def find_matching_endpoint(method, path, endpoint_impls): +def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> RouteMatch: """Find the matching endpoint implementation for a given method and path. Args: method: HTTP method (GET, POST, etc.) 
path: URL path to match against - endpoint_impls: A dictionary of endpoint implementations + route_impls: A dictionary of endpoint implementations Returns: A tuple of (endpoint_function, path_params, descriptive_name) @@ -113,7 +123,7 @@ def find_matching_endpoint(method, path, endpoint_impls): Raises: ValueError: If no matching endpoint is found """ - impls = endpoint_impls.get(method.lower()) + impls = route_impls.get(method.lower()) if not impls: raise ValueError(f"No endpoint found for {path}") diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index d70f06691..6c88bbfe9 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -6,6 +6,7 @@ import argparse import asyncio +import functools import inspect import json import os @@ -13,6 +14,7 @@ import ssl import sys import traceback import warnings +from collections.abc import Callable from contextlib import asynccontextmanager from importlib.metadata import version as parse_version from pathlib import Path @@ -20,6 +22,7 @@ from typing import Annotated, Any import rich.pretty import yaml +from aiohttp import hdrs from fastapi import Body, FastAPI, HTTPException, Request from fastapi import Path as FastapiPath from fastapi.exceptions import RequestValidationError @@ -35,9 +38,10 @@ from llama_stack.distribution.request_headers import ( request_provider_data_context, ) from llama_stack.distribution.resolver import InvalidProviderError -from llama_stack.distribution.server.endpoints import ( - find_matching_endpoint, - initialize_endpoint_impls, +from llama_stack.distribution.server.routes import ( + find_matching_route, + get_all_api_routes, + initialize_route_impls, ) from llama_stack.distribution.stack import ( construct_stack, @@ -60,7 +64,6 @@ from llama_stack.providers.utils.telemetry.tracing import ( ) from .auth import AuthenticationMiddleware -from .endpoints import get_all_api_endpoints from .quota import 
QuotaMiddleware REPO_ROOT = Path(__file__).parent.parent.parent.parent @@ -209,8 +212,9 @@ async def log_request_pre_validation(request: Request): logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}") -def create_dynamic_typed_route(func: Any, method: str, route: str): - async def endpoint(request: Request, **kwargs): +def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable: + @functools.wraps(func) + async def route_handler(request: Request, **kwargs): # Get auth attributes from the request scope user_attributes = request.scope.get("user_attributes", {}) @@ -250,9 +254,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str): for param in new_params[1:] ] - endpoint.__signature__ = sig.replace(parameters=new_params) + route_handler.__signature__ = sig.replace(parameters=new_params) - return endpoint + return route_handler class TracingMiddleware: @@ -274,14 +278,14 @@ class TracingMiddleware: logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}") return await self.app(scope, receive, send) - if not hasattr(self, "endpoint_impls"): - self.endpoint_impls = initialize_endpoint_impls(self.impls) + if not hasattr(self, "route_impls"): + self.route_impls = initialize_route_impls(self.impls) try: - _, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls) + _, _, trace_path = find_matching_route(scope.get("method", hdrs.METH_GET), path, self.route_impls) except ValueError: # If no matching endpoint is found, pass through to FastAPI - logger.debug(f"No matching endpoint found for path: {path}, falling back to FastAPI") + logger.debug(f"No matching route found for path: {path}, falling back to FastAPI") return await self.app(scope, receive, send) trace_attributes = {"__location__": "server", "raw_path": path} @@ -490,7 +494,7 @@ def main(args: argparse.Namespace | None = None): else: setup_logger(TelemetryAdapter(TelemetryConfig(), 
{})) - all_endpoints = get_all_api_endpoints() + all_routes = get_all_api_routes() if config.apis: apis_to_serve = set(config.apis) @@ -508,24 +512,29 @@ def main(args: argparse.Namespace | None = None): for api_str in apis_to_serve: api = Api(api_str) - endpoints = all_endpoints[api] + routes = all_routes[api] impl = impls[api] - for endpoint in endpoints: - if not hasattr(impl, endpoint.name): + for route in routes: + if not hasattr(impl, route.name): # ideally this should be a typing violation already - raise ValueError(f"Could not find method {endpoint.name} on {impl}!!") + raise ValueError(f"Could not find method {route.name} on {impl}!") - impl_method = getattr(impl, endpoint.name) - logger.debug(f"{endpoint.method.upper()} {endpoint.route}") + impl_method = getattr(impl, route.name) + # Filter out HEAD method since it's automatically handled by FastAPI for GET routes + available_methods = [m for m in route.methods if m != "HEAD"] + if not available_methods: + raise ValueError(f"No methods found for {route.name} on {impl}") + method = available_methods[0] + logger.debug(f"{method} {route.path}") with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields") - getattr(app, endpoint.method)(endpoint.route, response_model=None)( + getattr(app, method.lower())(route.path, response_model=None)( create_dynamic_typed_route( impl_method, - endpoint.method, - endpoint.route, + method.lower(), + route.path, ) ) diff --git a/pyproject.toml b/pyproject.toml index 043149b40..2bb6292aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ classifiers = [ "Topic :: Scientific/Engineering :: Information Analysis", ] dependencies = [ + "aiohttp", "fire", "httpx", "huggingface-hub", @@ -35,6 +36,7 @@ dependencies = [ "requests", "rich", "setuptools", + "starlette", "termcolor", "tiktoken", "pillow", diff --git a/requirements.txt b/requirements.txt index 0b77355d3..0c079a855 100644 --- a/requirements.txt 
+++ b/requirements.txt @@ -1,5 +1,11 @@ # This file was autogenerated by uv via the following command: # uv export --frozen --no-hashes --no-emit-project --no-default-groups --output-file=requirements.txt +aiohappyeyeballs==2.5.0 + # via aiohttp +aiohttp==3.11.13 + # via llama-stack +aiosignal==1.3.2 + # via aiohttp annotated-types==0.7.0 # via pydantic anyio==4.8.0 @@ -7,8 +13,12 @@ anyio==4.8.0 # httpx # llama-stack-client # openai + # starlette +async-timeout==5.0.1 ; python_full_version < '3.11' + # via aiohttp attrs==25.1.0 # via + # aiohttp # jsonschema # referencing certifi==2025.1.31 @@ -36,6 +46,10 @@ filelock==3.17.0 # via huggingface-hub fire==0.7.0 # via llama-stack +frozenlist==1.5.0 + # via + # aiohttp + # aiosignal fsspec==2024.12.0 # via huggingface-hub h11==0.16.0 @@ -56,6 +70,7 @@ idna==3.10 # anyio # httpx # requests + # yarl jinja2==3.1.6 # via llama-stack jiter==0.8.2 @@ -72,6 +87,10 @@ markupsafe==3.0.2 # via jinja2 mdurl==0.1.2 # via markdown-it-py +multidict==6.1.0 + # via + # aiohttp + # yarl numpy==2.2.3 # via pandas openai==1.71.0 @@ -86,6 +105,10 @@ prompt-toolkit==3.0.50 # via # llama-stack # llama-stack-client +propcache==0.3.0 + # via + # aiohttp + # yarl pyaml==25.1.0 # via llama-stack-client pyasn1==0.4.8 @@ -145,6 +168,8 @@ sniffio==1.3.1 # anyio # llama-stack-client # openai +starlette==0.45.3 + # via llama-stack termcolor==2.5.0 # via # fire @@ -162,6 +187,7 @@ typing-extensions==4.12.2 # anyio # huggingface-hub # llama-stack-client + # multidict # openai # pydantic # pydantic-core @@ -173,3 +199,5 @@ urllib3==2.3.0 # via requests wcwidth==0.2.13 # via prompt-toolkit +yarl==1.18.3 + # via aiohttp diff --git a/uv.lock b/uv.lock index e5168d5fe..dae04b5f6 100644 --- a/uv.lock +++ b/uv.lock @@ -1456,6 +1456,7 @@ name = "llama-stack" version = "0.2.8" source = { editable = "." 
} dependencies = [ + { name = "aiohttp" }, { name = "fire" }, { name = "h11" }, { name = "httpx" }, @@ -1472,6 +1473,7 @@ dependencies = [ { name = "requests" }, { name = "rich" }, { name = "setuptools" }, + { name = "starlette" }, { name = "termcolor" }, { name = "tiktoken" }, ] @@ -1557,6 +1559,7 @@ unit = [ [package.metadata] requires-dist = [ + { name = "aiohttp" }, { name = "fire" }, { name = "h11", specifier = ">=0.16.0" }, { name = "httpx" }, @@ -1575,6 +1578,7 @@ requires-dist = [ { name = "requests" }, { name = "rich" }, { name = "setuptools" }, + { name = "starlette" }, { name = "streamlit", marker = "extra == 'ui'" }, { name = "streamlit-option-menu", marker = "extra == 'ui'" }, { name = "termcolor" },