From f330c8eb2f51567bf44b12e618d23abc81a63629 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Tue, 25 Nov 2025 13:48:47 +0100 Subject: [PATCH] chore: simplify route addition when calling inspect MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit https://github.com/llamastack/llama-stack/pull/4191/files#r2557411918 Signed-off-by: Sébastien Han --- src/llama_stack/core/inspect.py | 114 +++++++++--------- .../core/server/fastapi_router_registry.py | 39 ++++++ 2 files changed, 97 insertions(+), 56 deletions(-) diff --git a/src/llama_stack/core/inspect.py b/src/llama_stack/core/inspect.py index 28d23b815..45cab2970 100644 --- a/src/llama_stack/core/inspect.py +++ b/src/llama_stack/core/inspect.py @@ -10,8 +10,11 @@ from pydantic import BaseModel from llama_stack.core.datatypes import StackRunConfig from llama_stack.core.external import load_external_apis -from llama_stack.core.resolver import api_protocol_map -from llama_stack.core.server.fastapi_router_registry import build_fastapi_router +from llama_stack.core.server.fastapi_router_registry import ( + _ROUTER_FACTORIES, + build_fastapi_router, + get_router_routes, +) from llama_stack.core.server.routes import get_all_api_routes from llama_stack_api import ( Api, @@ -46,6 +49,7 @@ class DistributionInspectImpl(Inspect): run_config: StackRunConfig = self.config.run_config # Helper function to determine if a route should be included based on api_filter + # TODO: remove this once we've migrated all APIs to FastAPI routers def should_include_route(webmethod) -> bool: if api_filter is None: # Default: only non-deprecated APIs @@ -57,40 +61,15 @@ class DistributionInspectImpl(Inspect): # Filter by API level (non-deprecated routes only) return not webmethod.deprecated and webmethod.level == api_filter - ret = [] - external_apis = load_external_apis(run_config) - all_endpoints = get_all_api_routes(external_apis) - # Helper function to get provider types for an API - def get_provider_types(api: Api) -> list[str]: + def _get_provider_types(api: Api) -> list[str]: if api.value in ["providers", "inspect"]: return [] # These APIs don't have "real" providers they're internal to the stack providers = run_config.providers.get(api.value, []) return [p.provider_type for p in providers] if providers else [] - # Process webmethod-based routes (legacy) - for api, endpoints in all_endpoints.items(): - # Skip APIs that have routers - they'll be processed separately - if build_fastapi_router(api, None) is not None: - continue - - provider_types = get_provider_types(api) - # Always include provider and inspect APIs, filter others based on run config - if api.value in ["providers", "inspect"] or provider_types: - ret.extend( - [ - RouteInfo( - route=e.path, - method=next(iter([m for m in e.methods if m != "HEAD"])), - provider_types=provider_types, - ) - for e, webmethod in endpoints - if e.methods is not None and should_include_route(webmethod) - ] - ) - # Helper function to determine if a router route should be included based on api_filter - def should_include_router_route(route, router_prefix: str | None) -> bool: + def _should_include_router_route(route, router_prefix: str | None) -> bool: """Check if a router-based route should be included based on api_filter.""" # Check deprecated status route_deprecated = getattr(route, "deprecated", False) or False @@ -109,36 +88,59 @@ class DistributionInspectImpl(Inspect): return not route_deprecated and prefix_level == api_filter return not route_deprecated - protocols = api_protocol_map(external_apis) - for api in protocols.keys(): - # For route inspection, we don't need a real implementation - router = build_fastapi_router(api, None) - if not router: + ret = [] + external_apis = load_external_apis(run_config) + all_endpoints = get_all_api_routes(external_apis) + + # Process routes from APIs with FastAPI routers + for api_name in _ROUTER_FACTORIES.keys(): + api = Api(api_name) + router = build_fastapi_router(api, None) # we don't need the impl here, just the routes + if router: + router_routes = get_router_routes(router) + for route in router_routes: + if _should_include_router_route(route, router.prefix): + ret.append( + RouteInfo( + route=route.path, + method=next(iter([m for m in route.methods if m != "HEAD"])), + provider_types=_get_provider_types(api), + ) + ) + + # Process routes from legacy webmethod-based APIs + for api, endpoints in all_endpoints.items(): + # Skip APIs that have routers (already processed above) + if api.value in _ROUTER_FACTORIES: continue - provider_types = get_provider_types(api) - # Only include if there are providers (or it's a special API) - if api.value in ["providers", "inspect"] or provider_types: - router_prefix = getattr(router, "prefix", None) - for route in router.routes: - # Extract HTTP methods from the route - # FastAPI routes have methods as a set - if hasattr(route, "methods") and route.methods: - methods = {m for m in route.methods if m != "HEAD"} - if methods and should_include_router_route(route, router_prefix): - # FastAPI already combines router prefix with route path - # Only APIRoute has a path attribute, use getattr to safely access it - path = getattr(route, "path", None) - if path is None: - continue - - ret.append( - RouteInfo( - route=path, - method=next(iter(methods)), - provider_types=provider_types, - ) + # Always include provider and inspect APIs, filter others based on run config + if api.value in ["providers", "inspect"]: + ret.extend( + [ + RouteInfo( + route=e.path, + method=next(iter([m for m in e.methods if m != "HEAD"])), + provider_types=[], # These APIs don't have "real" providers - they're internal to the stack + ) + for e, webmethod in endpoints + if e.methods is not None and should_include_route(webmethod) + ] + ) + else: + providers = run_config.providers.get(api.value, []) + if providers: # Only process if there are providers for this API + ret.extend( + [ + RouteInfo( + route=e.path, + method=next(iter([m for m in e.methods if m != "HEAD"])), + provider_types=[p.provider_type for p in providers], ) + for e, webmethod in endpoints + if e.methods is not None and should_include_route(webmethod) + ] + ) return ListRoutesResponse(data=ret) diff --git a/src/llama_stack/core/server/fastapi_router_registry.py b/src/llama_stack/core/server/fastapi_router_registry.py index 178622853..84f41693d 100644 --- a/src/llama_stack/core/server/fastapi_router_registry.py +++ b/src/llama_stack/core/server/fastapi_router_registry.py @@ -14,6 +14,8 @@ from collections.abc import Callable from typing import Any, cast from fastapi import APIRouter +from fastapi.routing import APIRoute +from starlette.routing import Route # Router factories for APIs that have FastAPI routers # Add new APIs here as they are migrated to the router system @@ -43,3 +45,40 @@ def build_fastapi_router(api: "Api", impl: Any) -> APIRouter | None: # If a router factory returns the wrong type, it will fail at runtime when # app.include_router(router) is called return cast(APIRouter, router_factory(impl)) + + +def get_router_routes(router: APIRouter) -> list[Route]: + """Extract routes from a FastAPI router. + + Args: + router: The FastAPI router to extract routes from + + Returns: + List of Route objects from the router + """ + routes = [] + + for route in router.routes: + # FastAPI routers use APIRoute objects, which have path and methods attributes + if isinstance(route, APIRoute): + # Combine router prefix with route path + routes.append( + Route( + path=route.path, + methods=route.methods, + name=route.name, + endpoint=route.endpoint, + ) + ) + elif isinstance(route, Route): + # Fallback for regular Starlette Route objects + routes.append( + Route( + path=route.path, + methods=route.methods, + name=route.name, + endpoint=route.endpoint, + ) + ) + + return routes