fix: use hardcoded list and dictionary mapping for router registry

Replace dynamic import-based router discovery with an explicit hardcoded
list of APIs that have routers.

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-11-24 09:56:39 +01:00
parent 03a31269ad
commit 49005f1a39
No known key found for this signature in database
5 changed files with 39 additions and 68 deletions

View file

@ -14,7 +14,7 @@ from typing import Any
from fastapi import FastAPI
from llama_stack.core.resolver import api_protocol_map
from llama_stack.core.server.fastapi_router_registry import build_router, has_router
from llama_stack.core.server.fastapi_router_registry import build_router
from llama_stack_api import Api
from .state import _protocol_methods_cache
@ -77,10 +77,9 @@ def create_llama_stack_app() -> FastAPI:
],
)
# Include routers for APIs that have them (automatic discovery)
# Include routers for APIs that have them
protocols = api_protocol_map()
for api in protocols.keys():
if has_router(api):
# For OpenAPI generation, we don't need a real implementation
router = build_router(api, None)
if router:
@ -96,7 +95,7 @@ def create_llama_stack_app() -> FastAPI:
for api, routes in api_routes.items():
# Skip APIs that have routers - they're already included above
if has_router(api):
if build_router(api, None) is not None:
continue
for route, webmethod in routes:

View file

@ -10,7 +10,7 @@ from pydantic import BaseModel
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.external import load_external_apis
from llama_stack.core.server.fastapi_router_registry import build_router, has_router
from llama_stack.core.server.fastapi_router_registry import build_router
from llama_stack.core.server.routes import get_all_api_routes
from llama_stack_api import (
Api,
@ -70,7 +70,7 @@ class DistributionInspectImpl(Inspect):
# Process webmethod-based routes (legacy)
for api, endpoints in all_endpoints.items():
# Skip APIs that have routers - they'll be processed separately
if has_router(api):
if build_router(api, None) is not None:
continue
provider_types = get_provider_types(api)
@ -113,9 +113,6 @@ class DistributionInspectImpl(Inspect):
protocols = api_protocol_map(external_apis)
for api in protocols.keys():
if not has_router(api):
continue
# For route inspection, we don't need a real implementation
router = build_router(api, None)
if not router:

View file

@ -6,11 +6,10 @@
"""Router utilities for FastAPI routers.
This module provides utilities to discover and create FastAPI routers from API packages.
Routers are automatically discovered by checking for fastapi_routes modules in each API package.
This module provides utilities to create FastAPI routers from API packages.
APIs with routers are explicitly listed here.
"""
import importlib
from typing import TYPE_CHECKING, Any, cast
from fastapi import APIRouter
@ -18,46 +17,30 @@ from fastapi import APIRouter
if TYPE_CHECKING:
from llama_stack_api.datatypes import Api
# Router factories for APIs that have FastAPI routers
# Add new APIs here as they are migrated to the router system
from llama_stack_api.batches.fastapi_routes import create_router as create_batches_router
def has_router(api: "Api") -> bool:
"""Check if an API has a router factory in its fastapi_routes module.
Args:
api: The API enum value
Returns:
True if the API has a fastapi_routes module with a create_router function
"""
try:
routes_module = importlib.import_module(f"llama_stack_api.{api.value}.fastapi_routes")
return hasattr(routes_module, "create_router")
except (ImportError, AttributeError):
return False
_ROUTER_FACTORIES: dict[str, APIRouter] = {
"batches": create_batches_router,
}
def build_router(api: "Api", impl: Any) -> APIRouter | None:
"""Build a router for an API by combining its router factory with the implementation.
This function discovers the router factory from the API package's routes module
and calls it with the implementation to create the final router instance.
Args:
api: The API enum value
impl: The implementation instance for the API
Returns:
APIRouter if the API has a fastapi_routes module with create_router, None otherwise
APIRouter if the API has a router factory, None otherwise
"""
try:
routes_module = importlib.import_module(f"llama_stack_api.{api.value}.fastapi_routes")
if hasattr(routes_module, "create_router"):
router_factory = routes_module.create_router
# cast is safe here: mypy can't verify the return type statically because
# we're dynamically importing the module. However, all router factories in
# API packages are required to return APIRouter. If a router factory returns the wrong
# type, it will fail at runtime when app.include_router(router) is called
return cast(APIRouter, router_factory(impl))
except (ImportError, AttributeError):
pass
router_factory = _ROUTER_FACTORIES.get(api.value)
if router_factory is None:
return None
# cast is safe here: all router factories in API packages are required to return APIRouter.
# If a router factory returns the wrong type, it will fail at runtime when
# app.include_router(router) is called
return cast(APIRouter, router_factory(impl))

View file

@ -44,7 +44,7 @@ from llama_stack.core.request_headers import (
request_provider_data_context,
user_from_scope,
)
from llama_stack.core.server.fastapi_router_registry import build_router, has_router
from llama_stack.core.server.fastapi_router_registry import build_router
from llama_stack.core.server.routes import get_all_api_routes
from llama_stack.core.stack import (
Stack,
@ -469,7 +469,6 @@ def create_app() -> StackApp:
api = Api(api_str)
# Try to discover and use a router factory from the API package
if has_router(api):
impl = impls[api]
router = build_router(api, impl)
if router:

View file

@ -6,11 +6,10 @@
from aiohttp import hdrs
from llama_stack.core.external import ExternalApiSpec
from llama_stack.core.server.fastapi_router_registry import has_router
from llama_stack.core.server.fastapi_router_registry import _ROUTER_FACTORIES
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
from llama_stack.core.telemetry.tracing import end_trace, start_trace
from llama_stack.log import get_logger
from llama_stack_api.datatypes import Api
from llama_stack_api.version import (
LLAMA_STACK_API_V1,
LLAMA_STACK_API_V1ALPHA,
@ -35,21 +34,15 @@ class TracingMiddleware:
"""Check if a path belongs to a router-based API.
Router-based APIs use FastAPI routers instead of the old webmethod system.
We need to check if the path matches any router-based API prefix.
Paths must start with a valid API level (v1, v1alpha, v1beta) followed by an API name.
"""
# Extract API name from path (e.g., /v1/batches -> batches)
# Paths must start with a valid API level: /v1/{api_name} or /v1alpha/{api_name} or /v1beta/{api_name}
parts = path.strip("/").split("/")
if len(parts) >= 2 and parts[0] in VALID_API_LEVELS:
api_name = parts[1]
try:
api = Api(api_name)
return has_router(api)
except (ValueError, KeyError):
# Not a known API or not router-based
return False
if len(parts) < 2 or parts[0] not in VALID_API_LEVELS:
return False
# Check directly if the API name is in the router factories list
return parts[1] in _ROUTER_FACTORIES
async def __call__(self, scope, receive, send):
if scope.get("type") == "lifespan":
return await self.app(scope, receive, send)