chore: simplify route addition when calling inspect

https://github.com/llamastack/llama-stack/pull/4191/files#r2557411918

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-11-25 13:48:47 +01:00
parent ead9e63ef8
commit f330c8eb2f
No known key found for this signature in database
2 changed files with 97 additions and 56 deletions

View file

@ -10,8 +10,11 @@ from pydantic import BaseModel
from llama_stack.core.datatypes import StackRunConfig from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.external import load_external_apis from llama_stack.core.external import load_external_apis
from llama_stack.core.resolver import api_protocol_map from llama_stack.core.server.fastapi_router_registry import (
from llama_stack.core.server.fastapi_router_registry import build_fastapi_router _ROUTER_FACTORIES,
build_fastapi_router,
get_router_routes,
)
from llama_stack.core.server.routes import get_all_api_routes from llama_stack.core.server.routes import get_all_api_routes
from llama_stack_api import ( from llama_stack_api import (
Api, Api,
@ -46,6 +49,7 @@ class DistributionInspectImpl(Inspect):
run_config: StackRunConfig = self.config.run_config run_config: StackRunConfig = self.config.run_config
# Helper function to determine if a route should be included based on api_filter # Helper function to determine if a route should be included based on api_filter
# TODO: remove this once we've migrated all APIs to FastAPI routers
def should_include_route(webmethod) -> bool: def should_include_route(webmethod) -> bool:
if api_filter is None: if api_filter is None:
# Default: only non-deprecated APIs # Default: only non-deprecated APIs
@ -57,40 +61,15 @@ class DistributionInspectImpl(Inspect):
# Filter by API level (non-deprecated routes only) # Filter by API level (non-deprecated routes only)
return not webmethod.deprecated and webmethod.level == api_filter return not webmethod.deprecated and webmethod.level == api_filter
ret = []
external_apis = load_external_apis(run_config)
all_endpoints = get_all_api_routes(external_apis)
# Helper function to get provider types for an API # Helper function to get provider types for an API
def get_provider_types(api: Api) -> list[str]: def _get_provider_types(api: Api) -> list[str]:
if api.value in ["providers", "inspect"]: if api.value in ["providers", "inspect"]:
return [] # These APIs don't have "real" providers they're internal to the stack return [] # These APIs don't have "real" providers they're internal to the stack
providers = run_config.providers.get(api.value, []) providers = run_config.providers.get(api.value, [])
return [p.provider_type for p in providers] if providers else [] return [p.provider_type for p in providers] if providers else []
# Process webmethod-based routes (legacy)
for api, endpoints in all_endpoints.items():
# Skip APIs that have routers - they'll be processed separately
if build_fastapi_router(api, None) is not None:
continue
provider_types = get_provider_types(api)
# Always include provider and inspect APIs, filter others based on run config
if api.value in ["providers", "inspect"] or provider_types:
ret.extend(
[
RouteInfo(
route=e.path,
method=next(iter([m for m in e.methods if m != "HEAD"])),
provider_types=provider_types,
)
for e, webmethod in endpoints
if e.methods is not None and should_include_route(webmethod)
]
)
# Helper function to determine if a router route should be included based on api_filter # Helper function to determine if a router route should be included based on api_filter
def should_include_router_route(route, router_prefix: str | None) -> bool: def _should_include_router_route(route, router_prefix: str | None) -> bool:
"""Check if a router-based route should be included based on api_filter.""" """Check if a router-based route should be included based on api_filter."""
# Check deprecated status # Check deprecated status
route_deprecated = getattr(route, "deprecated", False) or False route_deprecated = getattr(route, "deprecated", False) or False
@ -109,37 +88,60 @@ class DistributionInspectImpl(Inspect):
return not route_deprecated and prefix_level == api_filter return not route_deprecated and prefix_level == api_filter
return not route_deprecated return not route_deprecated
protocols = api_protocol_map(external_apis) ret = []
for api in protocols.keys(): external_apis = load_external_apis(run_config)
# For route inspection, we don't need a real implementation all_endpoints = get_all_api_routes(external_apis)
router = build_fastapi_router(api, None)
if not router:
continue
provider_types = get_provider_types(api)
# Only include if there are providers (or it's a special API)
if api.value in ["providers", "inspect"] or provider_types:
router_prefix = getattr(router, "prefix", None)
for route in router.routes:
# Extract HTTP methods from the route
# FastAPI routes have methods as a set
if hasattr(route, "methods") and route.methods:
methods = {m for m in route.methods if m != "HEAD"}
if methods and should_include_router_route(route, router_prefix):
# FastAPI already combines router prefix with route path
# Only APIRoute has a path attribute, use getattr to safely access it
path = getattr(route, "path", None)
if path is None:
continue
# Process routes from APIs with FastAPI routers
for api_name in _ROUTER_FACTORIES.keys():
api = Api(api_name)
router = build_fastapi_router(api, None) # we don't need the impl here, just the routes
if router:
router_routes = get_router_routes(router)
for route in router_routes:
if _should_include_router_route(route, router.prefix):
ret.append( ret.append(
RouteInfo( RouteInfo(
route=path, route=route.path,
method=next(iter(methods)), method=next(iter([m for m in route.methods if m != "HEAD"])),
provider_types=provider_types, provider_types=_get_provider_types(api),
) )
) )
# Process routes from legacy webmethod-based APIs
for api, endpoints in all_endpoints.items():
# Skip APIs that have routers (already processed above)
if api.value in _ROUTER_FACTORIES:
continue
# Always include provider and inspect APIs, filter others based on run config
if api.value in ["providers", "inspect"]:
ret.extend(
[
RouteInfo(
route=e.path,
method=next(iter([m for m in e.methods if m != "HEAD"])),
provider_types=[], # These APIs don't have "real" providers - they're internal to the stack
)
for e, webmethod in endpoints
if e.methods is not None and should_include_route(webmethod)
]
)
else:
providers = run_config.providers.get(api.value, [])
if providers: # Only process if there are providers for this API
ret.extend(
[
RouteInfo(
route=e.path,
method=next(iter([m for m in e.methods if m != "HEAD"])),
provider_types=[p.provider_type for p in providers],
)
for e, webmethod in endpoints
if e.methods is not None and should_include_route(webmethod)
]
)
return ListRoutesResponse(data=ret) return ListRoutesResponse(data=ret)
async def health(self) -> HealthInfo: async def health(self) -> HealthInfo:

View file

@ -14,6 +14,8 @@ from collections.abc import Callable
from typing import Any, cast from typing import Any, cast
from fastapi import APIRouter from fastapi import APIRouter
from fastapi.routing import APIRoute
from starlette.routing import Route
# Router factories for APIs that have FastAPI routers # Router factories for APIs that have FastAPI routers
# Add new APIs here as they are migrated to the router system # Add new APIs here as they are migrated to the router system
@ -43,3 +45,40 @@ def build_fastapi_router(api: "Api", impl: Any) -> APIRouter | None:
# If a router factory returns the wrong type, it will fail at runtime when # If a router factory returns the wrong type, it will fail at runtime when
# app.include_router(router) is called # app.include_router(router) is called
return cast(APIRouter, router_factory(impl)) return cast(APIRouter, router_factory(impl))
def get_router_routes(router: APIRouter) -> list[Route]:
"""Extract routes from a FastAPI router.
Args:
router: The FastAPI router to extract routes from
Returns:
List of Route objects from the router
"""
routes = []
for route in router.routes:
# FastAPI routers use APIRoute objects, which have path and methods attributes
if isinstance(route, APIRoute):
# Combine router prefix with route path
routes.append(
Route(
path=route.path,
methods=route.methods,
name=route.name,
endpoint=route.endpoint,
)
)
elif isinstance(route, Route):
# Fallback for regular Starlette Route objects
routes.append(
Route(
path=route.path,
methods=route.methods,
name=route.name,
endpoint=route.endpoint,
)
)
return routes