refactor: enhance route system with WebMethod metadata

- Add webmethod metadata extraction in routes.py
- Update route matching to return WebMethod objects
- Enhance tracing in library_client.py to use WebMethod info
- Add user_from_scope utility function
This commit is contained in:
Eric Huang 2025-07-24 10:49:50 -07:00
parent 1463b79218
commit e0206795d9
4 changed files with 29 additions and 21 deletions

View file

@ -42,8 +42,8 @@ class DistributionInspectImpl(Inspect):
run_config: StackRunConfig = self.config.run_config run_config: StackRunConfig = self.config.run_config
ret = [] ret = []
all_endpoints = get_all_api_routes() api_to_routes = get_all_api_routes()
for api, endpoints in all_endpoints.items(): for api, endpoints in api_to_routes.items():
# Always include provider and inspect APIs, filter others based on run config # Always include provider and inspect APIs, filter others based on run config
if api.value in ["providers", "inspect"]: if api.value in ["providers", "inspect"]:
ret.extend( ret.extend(
@ -53,7 +53,7 @@ class DistributionInspectImpl(Inspect):
method=next(iter([m for m in e.methods if m != "HEAD"])), method=next(iter([m for m in e.methods if m != "HEAD"])),
provider_types=[], # These APIs don't have "real" providers - they're internal to the stack provider_types=[], # These APIs don't have "real" providers - they're internal to the stack
) )
for e in endpoints for e, _ in endpoints
if e.methods is not None if e.methods is not None
] ]
) )
@ -67,7 +67,7 @@ class DistributionInspectImpl(Inspect):
method=next(iter([m for m in e.methods if m != "HEAD"])), method=next(iter([m for m in e.methods if m != "HEAD"])),
provider_types=[p.provider_type for p in providers], provider_types=[p.provider_type for p in providers],
) )
for e in endpoints for e, _ in endpoints
if e.methods is not None if e.methods is not None
] ]
) )

View file

@ -359,13 +359,15 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
body = options.params or {} body = options.params or {}
body |= options.json_data or {} body |= options.json_data or {}
matched_func, path_params, route = find_matching_route(options.method, path, self.route_impls) matched_func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
body |= path_params body |= path_params
body, field_names = self._handle_file_uploads(options, body) body, field_names = self._handle_file_uploads(options, body)
body = self._convert_body(path, options.method, body, exclude_params=set(field_names)) body = self._convert_body(path, options.method, body, exclude_params=set(field_names))
await start_trace(route, {"__location__": "library_client"})
trace_path = webmethod.descriptive_name or route_path
await start_trace(trace_path, {"__location__": "library_client"})
try: try:
result = await matched_func(**body) result = await matched_func(**body)
finally: finally:
@ -415,12 +417,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
path = options.url path = options.url
body = options.params or {} body = options.params or {}
body |= options.json_data or {} body |= options.json_data or {}
func, path_params, route = find_matching_route(options.method, path, self.route_impls) func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
body |= path_params body |= path_params
body = self._convert_body(path, options.method, body) body = self._convert_body(path, options.method, body)
await start_trace(route, {"__location__": "library_client"}) trace_path = webmethod.descriptive_name or route_path
await start_trace(trace_path, {"__location__": "library_client"})
async def gen(): async def gen():
try: try:
@ -475,7 +478,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
exclude_params = exclude_params or set() exclude_params = exclude_params or set()
func, _, _ = find_matching_route(method, path, self.route_impls) func, _, _, _ = find_matching_route(method, path, self.route_impls)
sig = inspect.signature(func) sig = inspect.signature(func)
# Strip NOT_GIVENs to use the defaults in signature # Strip NOT_GIVENs to use the defaults in signature

View file

@ -16,13 +16,14 @@ from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
from llama_stack.apis.version import LLAMA_STACK_API_VERSION from llama_stack.apis.version import LLAMA_STACK_API_VERSION
from llama_stack.distribution.resolver import api_protocol_map from llama_stack.distribution.resolver import api_protocol_map
from llama_stack.providers.datatypes import Api from llama_stack.providers.datatypes import Api
from llama_stack.schema_utils import WebMethod
EndpointFunc = Callable[..., Any] EndpointFunc = Callable[..., Any]
PathParams = dict[str, str] PathParams = dict[str, str]
RouteInfo = tuple[EndpointFunc, str] RouteInfo = tuple[EndpointFunc, str, WebMethod]
PathImpl = dict[str, RouteInfo] PathImpl = dict[str, RouteInfo]
RouteImpls = dict[str, PathImpl] RouteImpls = dict[str, PathImpl]
RouteMatch = tuple[EndpointFunc, PathParams, str] RouteMatch = tuple[EndpointFunc, PathParams, str, WebMethod]
def toolgroup_protocol_map(): def toolgroup_protocol_map():
@ -31,7 +32,7 @@ def toolgroup_protocol_map():
} }
def get_all_api_routes() -> dict[Api, list[Route]]: def get_all_api_routes() -> dict[Api, list[tuple[Route, WebMethod]]]:
apis = {} apis = {}
protocols = api_protocol_map() protocols = api_protocol_map()
@ -65,7 +66,7 @@ def get_all_api_routes() -> dict[Api, list[Route]]:
else: else:
http_method = hdrs.METH_POST http_method = hdrs.METH_POST
routes.append( routes.append(
Route(path=path, methods=[http_method], name=name, endpoint=None) (Route(path=path, methods=[http_method], name=name, endpoint=None), webmethod)
) # setting endpoint to None since don't use a Router object ) # setting endpoint to None since don't use a Router object
apis[api] = routes apis[api] = routes
@ -74,7 +75,7 @@ def get_all_api_routes() -> dict[Api, list[Route]]:
def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls: def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
routes = get_all_api_routes() api_to_routes = get_all_api_routes()
route_impls: RouteImpls = {} route_impls: RouteImpls = {}
def _convert_path_to_regex(path: str) -> str: def _convert_path_to_regex(path: str) -> str:
@ -88,10 +89,10 @@ def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
return f"^{pattern}$" return f"^{pattern}$"
for api, api_routes in routes.items(): for api, api_routes in api_to_routes.items():
if api not in impls: if api not in impls:
continue continue
for route in api_routes: for route, webmethod in api_routes:
impl = impls[api] impl = impls[api]
func = getattr(impl, route.name) func = getattr(impl, route.name)
# Get the first (and typically only) method from the set, filtering out HEAD # Get the first (and typically only) method from the set, filtering out HEAD
@ -104,6 +105,7 @@ def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
route_impls[method][_convert_path_to_regex(route.path)] = ( route_impls[method][_convert_path_to_regex(route.path)] = (
func, func,
route.path, route.path,
webmethod,
) )
return route_impls return route_impls
@ -118,7 +120,7 @@ def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> Rout
route_impls: A dictionary of endpoint implementations route_impls: A dictionary of endpoint implementations
Returns: Returns:
A tuple of (endpoint_function, path_params, descriptive_name) A tuple of (endpoint_function, path_params, route_path, webmethod_metadata)
Raises: Raises:
ValueError: If no matching endpoint is found ValueError: If no matching endpoint is found
@ -127,11 +129,11 @@ def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> Rout
if not impls: if not impls:
raise ValueError(f"No endpoint found for {path}") raise ValueError(f"No endpoint found for {path}")
for regex, (func, descriptive_name) in impls.items(): for regex, (func, route_path, webmethod) in impls.items():
match = re.match(regex, path) match = re.match(regex, path)
if match: if match:
# Extract named groups from the regex match # Extract named groups from the regex match
path_params = match.groupdict() path_params = match.groupdict()
return func, path_params, descriptive_name return func, path_params, route_path, webmethod
raise ValueError(f"No endpoint found for {path}") raise ValueError(f"No endpoint found for {path}")

View file

@ -304,7 +304,9 @@ class TracingMiddleware:
self.route_impls = initialize_route_impls(self.impls) self.route_impls = initialize_route_impls(self.impls)
try: try:
_, _, trace_path = find_matching_route(scope.get("method", hdrs.METH_GET), path, self.route_impls) _, _, route_path, webmethod = find_matching_route(
scope.get("method", hdrs.METH_GET), path, self.route_impls
)
except ValueError: except ValueError:
# If no matching endpoint is found, pass through to FastAPI # If no matching endpoint is found, pass through to FastAPI
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI") logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
@ -321,6 +323,7 @@ class TracingMiddleware:
if tracestate: if tracestate:
trace_attributes["tracestate"] = tracestate trace_attributes["tracestate"] = tracestate
trace_path = webmethod.descriptive_name or route_path
trace_context = await start_trace(trace_path, trace_attributes) trace_context = await start_trace(trace_path, trace_attributes)
async def send_with_trace_id(message): async def send_with_trace_id(message):
@ -504,7 +507,7 @@ def main(args: argparse.Namespace | None = None):
routes = all_routes[api] routes = all_routes[api]
impl = impls[api] impl = impls[api]
for route in routes: for route, _ in routes:
if not hasattr(impl, route.name): if not hasattr(impl, route.name):
# ideally this should be a typing violation already # ideally this should be a typing violation already
raise ValueError(f"Could not find method {route.name} on {impl}!") raise ValueError(f"Could not find method {route.name} on {impl}!")