mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
chore: simplify route addition when calling inspect
https://github.com/llamastack/llama-stack/pull/4191/files#r2557411918 Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
ead9e63ef8
commit
f330c8eb2f
2 changed files with 97 additions and 56 deletions
|
|
@ -10,8 +10,11 @@ from pydantic import BaseModel
|
|||
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.core.external import load_external_apis
|
||||
from llama_stack.core.resolver import api_protocol_map
|
||||
from llama_stack.core.server.fastapi_router_registry import build_fastapi_router
|
||||
from llama_stack.core.server.fastapi_router_registry import (
|
||||
_ROUTER_FACTORIES,
|
||||
build_fastapi_router,
|
||||
get_router_routes,
|
||||
)
|
||||
from llama_stack.core.server.routes import get_all_api_routes
|
||||
from llama_stack_api import (
|
||||
Api,
|
||||
|
|
@ -46,6 +49,7 @@ class DistributionInspectImpl(Inspect):
|
|||
run_config: StackRunConfig = self.config.run_config
|
||||
|
||||
# Helper function to determine if a route should be included based on api_filter
|
||||
# TODO: remove this once we've migrated all APIs to FastAPI routers
|
||||
def should_include_route(webmethod) -> bool:
|
||||
if api_filter is None:
|
||||
# Default: only non-deprecated APIs
|
||||
|
|
@ -57,40 +61,15 @@ class DistributionInspectImpl(Inspect):
|
|||
# Filter by API level (non-deprecated routes only)
|
||||
return not webmethod.deprecated and webmethod.level == api_filter
|
||||
|
||||
ret = []
|
||||
external_apis = load_external_apis(run_config)
|
||||
all_endpoints = get_all_api_routes(external_apis)
|
||||
|
||||
# Helper function to get provider types for an API
|
||||
def get_provider_types(api: Api) -> list[str]:
|
||||
def _get_provider_types(api: Api) -> list[str]:
|
||||
if api.value in ["providers", "inspect"]:
|
||||
return [] # These APIs don't have "real" providers they're internal to the stack
|
||||
providers = run_config.providers.get(api.value, [])
|
||||
return [p.provider_type for p in providers] if providers else []
|
||||
|
||||
# Process webmethod-based routes (legacy)
|
||||
for api, endpoints in all_endpoints.items():
|
||||
# Skip APIs that have routers - they'll be processed separately
|
||||
if build_fastapi_router(api, None) is not None:
|
||||
continue
|
||||
|
||||
provider_types = get_provider_types(api)
|
||||
# Always include provider and inspect APIs, filter others based on run config
|
||||
if api.value in ["providers", "inspect"] or provider_types:
|
||||
ret.extend(
|
||||
[
|
||||
RouteInfo(
|
||||
route=e.path,
|
||||
method=next(iter([m for m in e.methods if m != "HEAD"])),
|
||||
provider_types=provider_types,
|
||||
)
|
||||
for e, webmethod in endpoints
|
||||
if e.methods is not None and should_include_route(webmethod)
|
||||
]
|
||||
)
|
||||
|
||||
# Helper function to determine if a router route should be included based on api_filter
|
||||
def should_include_router_route(route, router_prefix: str | None) -> bool:
|
||||
def _should_include_router_route(route, router_prefix: str | None) -> bool:
|
||||
"""Check if a router-based route should be included based on api_filter."""
|
||||
# Check deprecated status
|
||||
route_deprecated = getattr(route, "deprecated", False) or False
|
||||
|
|
@ -109,37 +88,60 @@ class DistributionInspectImpl(Inspect):
|
|||
return not route_deprecated and prefix_level == api_filter
|
||||
return not route_deprecated
|
||||
|
||||
protocols = api_protocol_map(external_apis)
|
||||
for api in protocols.keys():
|
||||
# For route inspection, we don't need a real implementation
|
||||
router = build_fastapi_router(api, None)
|
||||
if not router:
|
||||
continue
|
||||
|
||||
provider_types = get_provider_types(api)
|
||||
# Only include if there are providers (or it's a special API)
|
||||
if api.value in ["providers", "inspect"] or provider_types:
|
||||
router_prefix = getattr(router, "prefix", None)
|
||||
for route in router.routes:
|
||||
# Extract HTTP methods from the route
|
||||
# FastAPI routes have methods as a set
|
||||
if hasattr(route, "methods") and route.methods:
|
||||
methods = {m for m in route.methods if m != "HEAD"}
|
||||
if methods and should_include_router_route(route, router_prefix):
|
||||
# FastAPI already combines router prefix with route path
|
||||
# Only APIRoute has a path attribute, use getattr to safely access it
|
||||
path = getattr(route, "path", None)
|
||||
if path is None:
|
||||
continue
|
||||
ret = []
|
||||
external_apis = load_external_apis(run_config)
|
||||
all_endpoints = get_all_api_routes(external_apis)
|
||||
|
||||
# Process routes from APIs with FastAPI routers
|
||||
for api_name in _ROUTER_FACTORIES.keys():
|
||||
api = Api(api_name)
|
||||
router = build_fastapi_router(api, None) # we don't need the impl here, just the routes
|
||||
if router:
|
||||
router_routes = get_router_routes(router)
|
||||
for route in router_routes:
|
||||
if _should_include_router_route(route, router.prefix):
|
||||
ret.append(
|
||||
RouteInfo(
|
||||
route=path,
|
||||
method=next(iter(methods)),
|
||||
provider_types=provider_types,
|
||||
route=route.path,
|
||||
method=next(iter([m for m in route.methods if m != "HEAD"])),
|
||||
provider_types=_get_provider_types(api),
|
||||
)
|
||||
)
|
||||
|
||||
# Process routes from legacy webmethod-based APIs
|
||||
for api, endpoints in all_endpoints.items():
|
||||
# Skip APIs that have routers (already processed above)
|
||||
if api.value in _ROUTER_FACTORIES:
|
||||
continue
|
||||
|
||||
# Always include provider and inspect APIs, filter others based on run config
|
||||
if api.value in ["providers", "inspect"]:
|
||||
ret.extend(
|
||||
[
|
||||
RouteInfo(
|
||||
route=e.path,
|
||||
method=next(iter([m for m in e.methods if m != "HEAD"])),
|
||||
provider_types=[], # These APIs don't have "real" providers - they're internal to the stack
|
||||
)
|
||||
for e, webmethod in endpoints
|
||||
if e.methods is not None and should_include_route(webmethod)
|
||||
]
|
||||
)
|
||||
else:
|
||||
providers = run_config.providers.get(api.value, [])
|
||||
if providers: # Only process if there are providers for this API
|
||||
ret.extend(
|
||||
[
|
||||
RouteInfo(
|
||||
route=e.path,
|
||||
method=next(iter([m for m in e.methods if m != "HEAD"])),
|
||||
provider_types=[p.provider_type for p in providers],
|
||||
)
|
||||
for e, webmethod in endpoints
|
||||
if e.methods is not None and should_include_route(webmethod)
|
||||
]
|
||||
)
|
||||
|
||||
return ListRoutesResponse(data=ret)
|
||||
|
||||
async def health(self) -> HealthInfo:
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@ from collections.abc import Callable
|
|||
from typing import Any, cast
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi.routing import APIRoute
|
||||
from starlette.routing import Route
|
||||
|
||||
# Router factories for APIs that have FastAPI routers
|
||||
# Add new APIs here as they are migrated to the router system
|
||||
|
|
@ -43,3 +45,40 @@ def build_fastapi_router(api: "Api", impl: Any) -> APIRouter | None:
|
|||
# If a router factory returns the wrong type, it will fail at runtime when
|
||||
# app.include_router(router) is called
|
||||
return cast(APIRouter, router_factory(impl))
|
||||
|
||||
|
||||
def get_router_routes(router: APIRouter) -> list[Route]:
|
||||
"""Extract routes from a FastAPI router.
|
||||
|
||||
Args:
|
||||
router: The FastAPI router to extract routes from
|
||||
|
||||
Returns:
|
||||
List of Route objects from the router
|
||||
"""
|
||||
routes = []
|
||||
|
||||
for route in router.routes:
|
||||
# FastAPI routers use APIRoute objects, which have path and methods attributes
|
||||
if isinstance(route, APIRoute):
|
||||
# Combine router prefix with route path
|
||||
routes.append(
|
||||
Route(
|
||||
path=route.path,
|
||||
methods=route.methods,
|
||||
name=route.name,
|
||||
endpoint=route.endpoint,
|
||||
)
|
||||
)
|
||||
elif isinstance(route, Route):
|
||||
# Fallback for regular Starlette Route objects
|
||||
routes.append(
|
||||
Route(
|
||||
path=route.path,
|
||||
methods=route.methods,
|
||||
name=route.name,
|
||||
endpoint=route.endpoint,
|
||||
)
|
||||
)
|
||||
|
||||
return routes
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue