From e0206795d9ce7f38daf6aab41bc8342a3da058fd Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Thu, 24 Jul 2025 10:49:50 -0700 Subject: [PATCH] refactor: enhance route system with WebMethod metadata - Add webmethod metadata extraction in routes.py - Update route matching to return WebMethod objects - Enhance tracing in library_client.py to use WebMethod info - Add user_from_scope utility function --- llama_stack/distribution/inspect.py | 8 ++++---- llama_stack/distribution/library_client.py | 13 ++++++++----- llama_stack/distribution/server/routes.py | 22 ++++++++++++---------- llama_stack/distribution/server/server.py | 7 +++++-- 4 files changed, 29 insertions(+), 21 deletions(-) diff --git a/llama_stack/distribution/inspect.py b/llama_stack/distribution/inspect.py index d6a598982..73339ea0a 100644 --- a/llama_stack/distribution/inspect.py +++ b/llama_stack/distribution/inspect.py @@ -42,8 +42,8 @@ class DistributionInspectImpl(Inspect): run_config: StackRunConfig = self.config.run_config ret = [] - all_endpoints = get_all_api_routes() - for api, endpoints in all_endpoints.items(): + api_to_routes = get_all_api_routes() + for api, endpoints in api_to_routes.items(): # Always include provider and inspect APIs, filter others based on run config if api.value in ["providers", "inspect"]: ret.extend( @@ -53,7 +53,7 @@ class DistributionInspectImpl(Inspect): method=next(iter([m for m in e.methods if m != "HEAD"])), provider_types=[], # These APIs don't have "real" providers - they're internal to the stack ) - for e in endpoints + for e, _ in endpoints if e.methods is not None ] ) @@ -67,7 +67,7 @@ class DistributionInspectImpl(Inspect): method=next(iter([m for m in e.methods if m != "HEAD"])), provider_types=[p.provider_type for p in providers], ) - for e in endpoints + for e, _ in endpoints if e.methods is not None ] ) diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index 5044fd8c8..07949aea7 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -359,13 +359,15 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): body = options.params or {} body |= options.json_data or {} - matched_func, path_params, route = find_matching_route(options.method, path, self.route_impls) + matched_func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls) body |= path_params body, field_names = self._handle_file_uploads(options, body) body = self._convert_body(path, options.method, body, exclude_params=set(field_names)) - await start_trace(route, {"__location__": "library_client"}) + + trace_path = webmethod.descriptive_name or route_path + await start_trace(trace_path, {"__location__": "library_client"}) try: result = await matched_func(**body) finally: @@ -415,12 +417,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): path = options.url body = options.params or {} body |= options.json_data or {} - func, path_params, route = find_matching_route(options.method, path, self.route_impls) + func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls) body |= path_params body = self._convert_body(path, options.method, body) - await start_trace(route, {"__location__": "library_client"}) + trace_path = webmethod.descriptive_name or route_path + await start_trace(trace_path, {"__location__": "library_client"}) async def gen(): try: @@ -475,7 +478,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): exclude_params = exclude_params or set() - func, _, _ = find_matching_route(method, path, self.route_impls) + func, _, _, _ = find_matching_route(method, path, self.route_impls) sig = inspect.signature(func) # Strip NOT_GIVENs to use the defaults in signature diff --git a/llama_stack/distribution/server/routes.py b/llama_stack/distribution/server/routes.py index ea66fec5a..4f67e9c3b 100644 --- a/llama_stack/distribution/server/routes.py +++ b/llama_stack/distribution/server/routes.py @@ -16,13 +16,14 @@ from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup from llama_stack.apis.version import LLAMA_STACK_API_VERSION from llama_stack.distribution.resolver import api_protocol_map from llama_stack.providers.datatypes import Api +from llama_stack.schema_utils import WebMethod EndpointFunc = Callable[..., Any] PathParams = dict[str, str] -RouteInfo = tuple[EndpointFunc, str] +RouteInfo = tuple[EndpointFunc, str, WebMethod] PathImpl = dict[str, RouteInfo] RouteImpls = dict[str, PathImpl] -RouteMatch = tuple[EndpointFunc, PathParams, str] +RouteMatch = tuple[EndpointFunc, PathParams, str, WebMethod] def toolgroup_protocol_map(): @@ -31,7 +32,7 @@ def toolgroup_protocol_map(): } -def get_all_api_routes() -> dict[Api, list[Route]]: +def get_all_api_routes() -> dict[Api, list[tuple[Route, WebMethod]]]: apis = {} protocols = api_protocol_map() @@ -65,7 +66,7 @@ def get_all_api_routes() -> dict[Api, list[Route]]: else: http_method = hdrs.METH_POST routes.append( - Route(path=path, methods=[http_method], name=name, endpoint=None) + (Route(path=path, methods=[http_method], name=name, endpoint=None), webmethod) ) # setting endpoint to None since don't use a Router object apis[api] = routes @@ -74,7 +75,7 @@ def get_all_api_routes() -> dict[Api, list[Route]]: def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls: - routes = get_all_api_routes() + api_to_routes = get_all_api_routes() route_impls: RouteImpls = {} def _convert_path_to_regex(path: str) -> str: @@ -88,10 +89,10 @@ def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls: return f"^{pattern}$" - for api, api_routes in routes.items(): + for api, api_routes in api_to_routes.items(): if api not in impls: continue - for route in api_routes: + for route, webmethod in api_routes: impl = impls[api] func = getattr(impl, route.name) # Get the first (and typically only) method from the set, filtering out HEAD @@ -104,6 +105,7 @@ def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls: route_impls[method][_convert_path_to_regex(route.path)] = ( func, route.path, + webmethod, ) return route_impls @@ -118,7 +120,7 @@ def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> Rout route_impls: A dictionary of endpoint implementations Returns: - A tuple of (endpoint_function, path_params, descriptive_name) + A tuple of (endpoint_function, path_params, route_path, webmethod_metadata) Raises: ValueError: If no matching endpoint is found @@ -127,11 +129,11 @@ def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> Rout if not impls: raise ValueError(f"No endpoint found for {path}") - for regex, (func, descriptive_name) in impls.items(): + for regex, (func, route_path, webmethod) in impls.items(): match = re.match(regex, path) if match: # Extract named groups from the regex match path_params = match.groupdict() - return func, path_params, descriptive_name + return func, path_params, route_path, webmethod raise ValueError(f"No endpoint found for {path}") diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index f05c4ad83..41ca05b2c 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -304,7 +304,9 @@ class TracingMiddleware: self.route_impls = initialize_route_impls(self.impls) try: - _, _, trace_path = find_matching_route(scope.get("method", hdrs.METH_GET), path, self.route_impls) + _, _, route_path, webmethod = find_matching_route( + scope.get("method", hdrs.METH_GET), path, self.route_impls + ) except ValueError: # If no matching endpoint is found, pass through to FastAPI logger.debug(f"No matching route found for path: {path}, falling back to FastAPI") @@ -321,6 +323,7 @@ class TracingMiddleware: if tracestate: trace_attributes["tracestate"] = tracestate + trace_path = webmethod.descriptive_name or route_path trace_context = await start_trace(trace_path, trace_attributes) async def send_with_trace_id(message): @@ -504,7 +507,7 @@ def main(args: argparse.Namespace | None = None): routes = all_routes[api] impl = impls[api] - for route in routes: + for route, _ in routes: if not hasattr(impl, route.name): # ideally this should be a typing violation already raise ValueError(f"Could not find method {route.name} on {impl}!")