mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
chore: simplify route addition when calling inspect
https://github.com/llamastack/llama-stack/pull/4191/files#r2557411918 Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
ead9e63ef8
commit
f330c8eb2f
2 changed files with 97 additions and 56 deletions
|
|
@ -10,8 +10,11 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_stack.core.datatypes import StackRunConfig
|
from llama_stack.core.datatypes import StackRunConfig
|
||||||
from llama_stack.core.external import load_external_apis
|
from llama_stack.core.external import load_external_apis
|
||||||
from llama_stack.core.resolver import api_protocol_map
|
from llama_stack.core.server.fastapi_router_registry import (
|
||||||
from llama_stack.core.server.fastapi_router_registry import build_fastapi_router
|
_ROUTER_FACTORIES,
|
||||||
|
build_fastapi_router,
|
||||||
|
get_router_routes,
|
||||||
|
)
|
||||||
from llama_stack.core.server.routes import get_all_api_routes
|
from llama_stack.core.server.routes import get_all_api_routes
|
||||||
from llama_stack_api import (
|
from llama_stack_api import (
|
||||||
Api,
|
Api,
|
||||||
|
|
@ -46,6 +49,7 @@ class DistributionInspectImpl(Inspect):
|
||||||
run_config: StackRunConfig = self.config.run_config
|
run_config: StackRunConfig = self.config.run_config
|
||||||
|
|
||||||
# Helper function to determine if a route should be included based on api_filter
|
# Helper function to determine if a route should be included based on api_filter
|
||||||
|
# TODO: remove this once we've migrated all APIs to FastAPI routers
|
||||||
def should_include_route(webmethod) -> bool:
|
def should_include_route(webmethod) -> bool:
|
||||||
if api_filter is None:
|
if api_filter is None:
|
||||||
# Default: only non-deprecated APIs
|
# Default: only non-deprecated APIs
|
||||||
|
|
@ -57,40 +61,15 @@ class DistributionInspectImpl(Inspect):
|
||||||
# Filter by API level (non-deprecated routes only)
|
# Filter by API level (non-deprecated routes only)
|
||||||
return not webmethod.deprecated and webmethod.level == api_filter
|
return not webmethod.deprecated and webmethod.level == api_filter
|
||||||
|
|
||||||
ret = []
|
|
||||||
external_apis = load_external_apis(run_config)
|
|
||||||
all_endpoints = get_all_api_routes(external_apis)
|
|
||||||
|
|
||||||
# Helper function to get provider types for an API
|
# Helper function to get provider types for an API
|
||||||
def get_provider_types(api: Api) -> list[str]:
|
def _get_provider_types(api: Api) -> list[str]:
|
||||||
if api.value in ["providers", "inspect"]:
|
if api.value in ["providers", "inspect"]:
|
||||||
return [] # These APIs don't have "real" providers they're internal to the stack
|
return [] # These APIs don't have "real" providers they're internal to the stack
|
||||||
providers = run_config.providers.get(api.value, [])
|
providers = run_config.providers.get(api.value, [])
|
||||||
return [p.provider_type for p in providers] if providers else []
|
return [p.provider_type for p in providers] if providers else []
|
||||||
|
|
||||||
# Process webmethod-based routes (legacy)
|
|
||||||
for api, endpoints in all_endpoints.items():
|
|
||||||
# Skip APIs that have routers - they'll be processed separately
|
|
||||||
if build_fastapi_router(api, None) is not None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
provider_types = get_provider_types(api)
|
|
||||||
# Always include provider and inspect APIs, filter others based on run config
|
|
||||||
if api.value in ["providers", "inspect"] or provider_types:
|
|
||||||
ret.extend(
|
|
||||||
[
|
|
||||||
RouteInfo(
|
|
||||||
route=e.path,
|
|
||||||
method=next(iter([m for m in e.methods if m != "HEAD"])),
|
|
||||||
provider_types=provider_types,
|
|
||||||
)
|
|
||||||
for e, webmethod in endpoints
|
|
||||||
if e.methods is not None and should_include_route(webmethod)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Helper function to determine if a router route should be included based on api_filter
|
# Helper function to determine if a router route should be included based on api_filter
|
||||||
def should_include_router_route(route, router_prefix: str | None) -> bool:
|
def _should_include_router_route(route, router_prefix: str | None) -> bool:
|
||||||
"""Check if a router-based route should be included based on api_filter."""
|
"""Check if a router-based route should be included based on api_filter."""
|
||||||
# Check deprecated status
|
# Check deprecated status
|
||||||
route_deprecated = getattr(route, "deprecated", False) or False
|
route_deprecated = getattr(route, "deprecated", False) or False
|
||||||
|
|
@ -109,36 +88,59 @@ class DistributionInspectImpl(Inspect):
|
||||||
return not route_deprecated and prefix_level == api_filter
|
return not route_deprecated and prefix_level == api_filter
|
||||||
return not route_deprecated
|
return not route_deprecated
|
||||||
|
|
||||||
protocols = api_protocol_map(external_apis)
|
ret = []
|
||||||
for api in protocols.keys():
|
external_apis = load_external_apis(run_config)
|
||||||
# For route inspection, we don't need a real implementation
|
all_endpoints = get_all_api_routes(external_apis)
|
||||||
router = build_fastapi_router(api, None)
|
|
||||||
if not router:
|
# Process routes from APIs with FastAPI routers
|
||||||
|
for api_name in _ROUTER_FACTORIES.keys():
|
||||||
|
api = Api(api_name)
|
||||||
|
router = build_fastapi_router(api, None) # we don't need the impl here, just the routes
|
||||||
|
if router:
|
||||||
|
router_routes = get_router_routes(router)
|
||||||
|
for route in router_routes:
|
||||||
|
if _should_include_router_route(route, router.prefix):
|
||||||
|
ret.append(
|
||||||
|
RouteInfo(
|
||||||
|
route=route.path,
|
||||||
|
method=next(iter([m for m in route.methods if m != "HEAD"])),
|
||||||
|
provider_types=_get_provider_types(api),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process routes from legacy webmethod-based APIs
|
||||||
|
for api, endpoints in all_endpoints.items():
|
||||||
|
# Skip APIs that have routers (already processed above)
|
||||||
|
if api.value in _ROUTER_FACTORIES:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
provider_types = get_provider_types(api)
|
# Always include provider and inspect APIs, filter others based on run config
|
||||||
# Only include if there are providers (or it's a special API)
|
if api.value in ["providers", "inspect"]:
|
||||||
if api.value in ["providers", "inspect"] or provider_types:
|
ret.extend(
|
||||||
router_prefix = getattr(router, "prefix", None)
|
[
|
||||||
for route in router.routes:
|
RouteInfo(
|
||||||
# Extract HTTP methods from the route
|
route=e.path,
|
||||||
# FastAPI routes have methods as a set
|
method=next(iter([m for m in e.methods if m != "HEAD"])),
|
||||||
if hasattr(route, "methods") and route.methods:
|
provider_types=[], # These APIs don't have "real" providers - they're internal to the stack
|
||||||
methods = {m for m in route.methods if m != "HEAD"}
|
)
|
||||||
if methods and should_include_router_route(route, router_prefix):
|
for e, webmethod in endpoints
|
||||||
# FastAPI already combines router prefix with route path
|
if e.methods is not None and should_include_route(webmethod)
|
||||||
# Only APIRoute has a path attribute, use getattr to safely access it
|
]
|
||||||
path = getattr(route, "path", None)
|
)
|
||||||
if path is None:
|
else:
|
||||||
continue
|
providers = run_config.providers.get(api.value, [])
|
||||||
|
if providers: # Only process if there are providers for this API
|
||||||
ret.append(
|
ret.extend(
|
||||||
RouteInfo(
|
[
|
||||||
route=path,
|
RouteInfo(
|
||||||
method=next(iter(methods)),
|
route=e.path,
|
||||||
provider_types=provider_types,
|
method=next(iter([m for m in e.methods if m != "HEAD"])),
|
||||||
)
|
provider_types=[p.provider_type for p in providers],
|
||||||
)
|
)
|
||||||
|
for e, webmethod in endpoints
|
||||||
|
if e.methods is not None and should_include_route(webmethod)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
return ListRoutesResponse(data=ret)
|
return ListRoutesResponse(data=ret)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,8 @@ from collections.abc import Callable
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
from fastapi.routing import APIRoute
|
||||||
|
from starlette.routing import Route
|
||||||
|
|
||||||
# Router factories for APIs that have FastAPI routers
|
# Router factories for APIs that have FastAPI routers
|
||||||
# Add new APIs here as they are migrated to the router system
|
# Add new APIs here as they are migrated to the router system
|
||||||
|
|
@ -43,3 +45,40 @@ def build_fastapi_router(api: "Api", impl: Any) -> APIRouter | None:
|
||||||
# If a router factory returns the wrong type, it will fail at runtime when
|
# If a router factory returns the wrong type, it will fail at runtime when
|
||||||
# app.include_router(router) is called
|
# app.include_router(router) is called
|
||||||
return cast(APIRouter, router_factory(impl))
|
return cast(APIRouter, router_factory(impl))
|
||||||
|
|
||||||
|
|
||||||
|
def get_router_routes(router: APIRouter) -> list[Route]:
|
||||||
|
"""Extract routes from a FastAPI router.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
router: The FastAPI router to extract routes from
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Route objects from the router
|
||||||
|
"""
|
||||||
|
routes = []
|
||||||
|
|
||||||
|
for route in router.routes:
|
||||||
|
# FastAPI routers use APIRoute objects, which have path and methods attributes
|
||||||
|
if isinstance(route, APIRoute):
|
||||||
|
# Combine router prefix with route path
|
||||||
|
routes.append(
|
||||||
|
Route(
|
||||||
|
path=route.path,
|
||||||
|
methods=route.methods,
|
||||||
|
name=route.name,
|
||||||
|
endpoint=route.endpoint,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif isinstance(route, Route):
|
||||||
|
# Fallback for regular Starlette Route objects
|
||||||
|
routes.append(
|
||||||
|
Route(
|
||||||
|
path=route.path,
|
||||||
|
methods=route.methods,
|
||||||
|
name=route.name,
|
||||||
|
endpoint=route.endpoint,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return routes
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue