forked from phoenix-oss/llama-stack-mirror
chore: use starlette built-in Route class (#2267)
# What does this PR do? Use a more common pattern and known terminology from the ecosystem, where Route is more approved than Endpoint. Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
56e5ddb39f
commit
63a9f08c9e
7 changed files with 131 additions and 72 deletions
|
@ -16,7 +16,7 @@ from llama_stack.apis.inspect import (
|
|||
VersionInfo,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
||||
from llama_stack.distribution.server.routes import get_all_api_routes
|
||||
from llama_stack.providers.datatypes import HealthStatus
|
||||
|
||||
|
||||
|
@ -42,15 +42,15 @@ class DistributionInspectImpl(Inspect):
|
|||
run_config: StackRunConfig = self.config.run_config
|
||||
|
||||
ret = []
|
||||
all_endpoints = get_all_api_endpoints()
|
||||
all_endpoints = get_all_api_routes()
|
||||
for api, endpoints in all_endpoints.items():
|
||||
# Always include provider and inspect APIs, filter others based on run config
|
||||
if api.value in ["providers", "inspect"]:
|
||||
ret.extend(
|
||||
[
|
||||
RouteInfo(
|
||||
route=e.route,
|
||||
method=e.method,
|
||||
route=e.path,
|
||||
method=next(iter([m for m in e.methods if m != "HEAD"])),
|
||||
provider_types=[], # These APIs don't have "real" providers - they're internal to the stack
|
||||
)
|
||||
for e in endpoints
|
||||
|
@ -62,8 +62,8 @@ class DistributionInspectImpl(Inspect):
|
|||
ret.extend(
|
||||
[
|
||||
RouteInfo(
|
||||
route=e.route,
|
||||
method=e.method,
|
||||
route=e.path,
|
||||
method=next(iter([m for m in e.methods if m != "HEAD"])),
|
||||
provider_types=[p.provider_type for p in providers],
|
||||
)
|
||||
for e in endpoints
|
||||
|
|
|
@ -37,10 +37,7 @@ from llama_stack.distribution.request_headers import (
|
|||
request_provider_data_context,
|
||||
)
|
||||
from llama_stack.distribution.resolver import ProviderRegistry
|
||||
from llama_stack.distribution.server.endpoints import (
|
||||
find_matching_endpoint,
|
||||
initialize_endpoint_impls,
|
||||
)
|
||||
from llama_stack.distribution.server.routes import find_matching_route, initialize_route_impls
|
||||
from llama_stack.distribution.stack import (
|
||||
construct_stack,
|
||||
get_stack_run_config_from_template,
|
||||
|
@ -208,7 +205,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
|
||||
async def initialize(self) -> bool:
|
||||
try:
|
||||
self.endpoint_impls = None
|
||||
self.route_impls = None
|
||||
self.impls = await construct_stack(self.config, self.custom_provider_registry)
|
||||
except ModuleNotFoundError as _e:
|
||||
cprint(_e.msg, color="red", file=sys.stderr)
|
||||
|
@ -254,7 +251,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
safe_config = redact_sensitive_fields(self.config.model_dump())
|
||||
console.print(yaml.dump(safe_config, indent=2))
|
||||
|
||||
self.endpoint_impls = initialize_endpoint_impls(self.impls)
|
||||
self.route_impls = initialize_route_impls(self.impls)
|
||||
return True
|
||||
|
||||
async def request(
|
||||
|
@ -265,7 +262,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
stream=False,
|
||||
stream_cls=None,
|
||||
):
|
||||
if not self.endpoint_impls:
|
||||
if not self.route_impls:
|
||||
raise ValueError("Client not initialized")
|
||||
|
||||
# Create headers with provider data if available
|
||||
|
@ -296,11 +293,14 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
cast_to: Any,
|
||||
options: Any,
|
||||
):
|
||||
if self.route_impls is None:
|
||||
raise ValueError("Client not initialized")
|
||||
|
||||
path = options.url
|
||||
body = options.params or {}
|
||||
body |= options.json_data or {}
|
||||
|
||||
matched_func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
|
||||
matched_func, path_params, route = find_matching_route(options.method, path, self.route_impls)
|
||||
body |= path_params
|
||||
body = self._convert_body(path, options.method, body)
|
||||
await start_trace(route, {"__location__": "library_client"})
|
||||
|
@ -342,10 +342,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
options: Any,
|
||||
stream_cls: Any,
|
||||
):
|
||||
if self.route_impls is None:
|
||||
raise ValueError("Client not initialized")
|
||||
|
||||
path = options.url
|
||||
body = options.params or {}
|
||||
body |= options.json_data or {}
|
||||
func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
|
||||
func, path_params, route = find_matching_route(options.method, path, self.route_impls)
|
||||
body |= path_params
|
||||
|
||||
body = self._convert_body(path, options.method, body)
|
||||
|
@ -397,7 +400,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
if not body:
|
||||
return {}
|
||||
|
||||
func, _, _ = find_matching_endpoint(method, path, self.endpoint_impls)
|
||||
if self.route_impls is None:
|
||||
raise ValueError("Client not initialized")
|
||||
|
||||
func, _, _ = find_matching_route(method, path, self.route_impls)
|
||||
sig = inspect.signature(func)
|
||||
|
||||
# Strip NOT_GIVENs to use the defaults in signature
|
||||
|
|
|
@ -6,20 +6,23 @@
|
|||
|
||||
import inspect
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from aiohttp import hdrs
|
||||
from starlette.routing import Route
|
||||
|
||||
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
|
||||
from llama_stack.distribution.resolver import api_protocol_map
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
|
||||
class ApiEndpoint(BaseModel):
|
||||
route: str
|
||||
method: str
|
||||
name: str
|
||||
descriptive_name: str | None = None
|
||||
EndpointFunc = Callable[..., Any]
|
||||
PathParams = dict[str, str]
|
||||
RouteInfo = tuple[EndpointFunc, str]
|
||||
PathImpl = dict[str, RouteInfo]
|
||||
RouteImpls = dict[str, PathImpl]
|
||||
RouteMatch = tuple[EndpointFunc, PathParams, str]
|
||||
|
||||
|
||||
def toolgroup_protocol_map():
|
||||
|
@ -28,13 +31,13 @@ def toolgroup_protocol_map():
|
|||
}
|
||||
|
||||
|
||||
def get_all_api_endpoints() -> dict[Api, list[ApiEndpoint]]:
|
||||
def get_all_api_routes() -> dict[Api, list[Route]]:
|
||||
apis = {}
|
||||
|
||||
protocols = api_protocol_map()
|
||||
toolgroup_protocols = toolgroup_protocol_map()
|
||||
for api, protocol in protocols.items():
|
||||
endpoints = []
|
||||
routes = []
|
||||
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
|
||||
|
||||
# HACK ALERT
|
||||
|
@ -51,26 +54,28 @@ def get_all_api_endpoints() -> dict[Api, list[ApiEndpoint]]:
|
|||
if not hasattr(method, "__webmethod__"):
|
||||
continue
|
||||
|
||||
webmethod = method.__webmethod__
|
||||
route = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
|
||||
if webmethod.method == "GET":
|
||||
method = "get"
|
||||
elif webmethod.method == "DELETE":
|
||||
method = "delete"
|
||||
# The __webmethod__ attribute is dynamically added by the @webmethod decorator
|
||||
# mypy doesn't know about this dynamic attribute, so we ignore the attr-defined error
|
||||
webmethod = method.__webmethod__ # type: ignore[attr-defined]
|
||||
path = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
|
||||
if webmethod.method == hdrs.METH_GET:
|
||||
http_method = hdrs.METH_GET
|
||||
elif webmethod.method == hdrs.METH_DELETE:
|
||||
http_method = hdrs.METH_DELETE
|
||||
else:
|
||||
method = "post"
|
||||
endpoints.append(
|
||||
ApiEndpoint(route=route, method=method, name=name, descriptive_name=webmethod.descriptive_name)
|
||||
)
|
||||
http_method = hdrs.METH_POST
|
||||
routes.append(
|
||||
Route(path=path, methods=[http_method], name=name, endpoint=None)
|
||||
) # setting endpoint to None since don't use a Router object
|
||||
|
||||
apis[api] = endpoints
|
||||
apis[api] = routes
|
||||
|
||||
return apis
|
||||
|
||||
|
||||
def initialize_endpoint_impls(impls):
|
||||
endpoints = get_all_api_endpoints()
|
||||
endpoint_impls = {}
|
||||
def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
|
||||
routes = get_all_api_routes()
|
||||
route_impls: RouteImpls = {}
|
||||
|
||||
def _convert_path_to_regex(path: str) -> str:
|
||||
# Convert {param} to named capture groups
|
||||
|
@ -83,29 +88,34 @@ def initialize_endpoint_impls(impls):
|
|||
|
||||
return f"^{pattern}$"
|
||||
|
||||
for api, api_endpoints in endpoints.items():
|
||||
for api, api_routes in routes.items():
|
||||
if api not in impls:
|
||||
continue
|
||||
for endpoint in api_endpoints:
|
||||
for route in api_routes:
|
||||
impl = impls[api]
|
||||
func = getattr(impl, endpoint.name)
|
||||
if endpoint.method not in endpoint_impls:
|
||||
endpoint_impls[endpoint.method] = {}
|
||||
endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = (
|
||||
func = getattr(impl, route.name)
|
||||
# Get the first (and typically only) method from the set, filtering out HEAD
|
||||
available_methods = [m for m in route.methods if m != "HEAD"]
|
||||
if not available_methods:
|
||||
continue # Skip if only HEAD method is available
|
||||
method = available_methods[0].lower()
|
||||
if method not in route_impls:
|
||||
route_impls[method] = {}
|
||||
route_impls[method][_convert_path_to_regex(route.path)] = (
|
||||
func,
|
||||
endpoint.descriptive_name or endpoint.route,
|
||||
route.path,
|
||||
)
|
||||
|
||||
return endpoint_impls
|
||||
return route_impls
|
||||
|
||||
|
||||
def find_matching_endpoint(method, path, endpoint_impls):
|
||||
def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> RouteMatch:
|
||||
"""Find the matching endpoint implementation for a given method and path.
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET, POST, etc.)
|
||||
path: URL path to match against
|
||||
endpoint_impls: A dictionary of endpoint implementations
|
||||
route_impls: A dictionary of endpoint implementations
|
||||
|
||||
Returns:
|
||||
A tuple of (endpoint_function, path_params, descriptive_name)
|
||||
|
@ -113,7 +123,7 @@ def find_matching_endpoint(method, path, endpoint_impls):
|
|||
Raises:
|
||||
ValueError: If no matching endpoint is found
|
||||
"""
|
||||
impls = endpoint_impls.get(method.lower())
|
||||
impls = route_impls.get(method.lower())
|
||||
if not impls:
|
||||
raise ValueError(f"No endpoint found for {path}")
|
||||
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
import argparse
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
|
@ -13,6 +14,7 @@ import ssl
|
|||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from contextlib import asynccontextmanager
|
||||
from importlib.metadata import version as parse_version
|
||||
from pathlib import Path
|
||||
|
@ -20,6 +22,7 @@ from typing import Annotated, Any
|
|||
|
||||
import rich.pretty
|
||||
import yaml
|
||||
from aiohttp import hdrs
|
||||
from fastapi import Body, FastAPI, HTTPException, Request
|
||||
from fastapi import Path as FastapiPath
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
|
@ -35,9 +38,10 @@ from llama_stack.distribution.request_headers import (
|
|||
request_provider_data_context,
|
||||
)
|
||||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
from llama_stack.distribution.server.endpoints import (
|
||||
find_matching_endpoint,
|
||||
initialize_endpoint_impls,
|
||||
from llama_stack.distribution.server.routes import (
|
||||
find_matching_route,
|
||||
get_all_api_routes,
|
||||
initialize_route_impls,
|
||||
)
|
||||
from llama_stack.distribution.stack import (
|
||||
construct_stack,
|
||||
|
@ -60,7 +64,6 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
|||
)
|
||||
|
||||
from .auth import AuthenticationMiddleware
|
||||
from .endpoints import get_all_api_endpoints
|
||||
from .quota import QuotaMiddleware
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
@ -209,8 +212,9 @@ async def log_request_pre_validation(request: Request):
|
|||
logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")
|
||||
|
||||
|
||||
def create_dynamic_typed_route(func: Any, method: str, route: str):
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
||||
@functools.wraps(func)
|
||||
async def route_handler(request: Request, **kwargs):
|
||||
# Get auth attributes from the request scope
|
||||
user_attributes = request.scope.get("user_attributes", {})
|
||||
|
||||
|
@ -250,9 +254,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
|
|||
for param in new_params[1:]
|
||||
]
|
||||
|
||||
endpoint.__signature__ = sig.replace(parameters=new_params)
|
||||
route_handler.__signature__ = sig.replace(parameters=new_params)
|
||||
|
||||
return endpoint
|
||||
return route_handler
|
||||
|
||||
|
||||
class TracingMiddleware:
|
||||
|
@ -274,14 +278,14 @@ class TracingMiddleware:
|
|||
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
if not hasattr(self, "endpoint_impls"):
|
||||
self.endpoint_impls = initialize_endpoint_impls(self.impls)
|
||||
if not hasattr(self, "route_impls"):
|
||||
self.route_impls = initialize_route_impls(self.impls)
|
||||
|
||||
try:
|
||||
_, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls)
|
||||
_, _, trace_path = find_matching_route(scope.get("method", hdrs.METH_GET), path, self.route_impls)
|
||||
except ValueError:
|
||||
# If no matching endpoint is found, pass through to FastAPI
|
||||
logger.debug(f"No matching endpoint found for path: {path}, falling back to FastAPI")
|
||||
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
trace_attributes = {"__location__": "server", "raw_path": path}
|
||||
|
@ -490,7 +494,7 @@ def main(args: argparse.Namespace | None = None):
|
|||
else:
|
||||
setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
|
||||
|
||||
all_endpoints = get_all_api_endpoints()
|
||||
all_routes = get_all_api_routes()
|
||||
|
||||
if config.apis:
|
||||
apis_to_serve = set(config.apis)
|
||||
|
@ -508,24 +512,29 @@ def main(args: argparse.Namespace | None = None):
|
|||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
|
||||
endpoints = all_endpoints[api]
|
||||
routes = all_routes[api]
|
||||
impl = impls[api]
|
||||
|
||||
for endpoint in endpoints:
|
||||
if not hasattr(impl, endpoint.name):
|
||||
for route in routes:
|
||||
if not hasattr(impl, route.name):
|
||||
# ideally this should be a typing violation already
|
||||
raise ValueError(f"Could not find method {endpoint.name} on {impl}!!")
|
||||
raise ValueError(f"Could not find method {route.name} on {impl}!")
|
||||
|
||||
impl_method = getattr(impl, endpoint.name)
|
||||
logger.debug(f"{endpoint.method.upper()} {endpoint.route}")
|
||||
impl_method = getattr(impl, route.name)
|
||||
# Filter out HEAD method since it's automatically handled by FastAPI for GET routes
|
||||
available_methods = [m for m in route.methods if m != "HEAD"]
|
||||
if not available_methods:
|
||||
raise ValueError(f"No methods found for {route.name} on {impl}")
|
||||
method = available_methods[0]
|
||||
logger.debug(f"{method} {route.path}")
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
|
||||
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
|
||||
getattr(app, method.lower())(route.path, response_model=None)(
|
||||
create_dynamic_typed_route(
|
||||
impl_method,
|
||||
endpoint.method,
|
||||
endpoint.route,
|
||||
method.lower(),
|
||||
route.path,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue