mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
fix: use hardcoded list and dictionary mapping for router registry
Replace dynamic import-based router discovery with an explicit hardcoded list of APIs that have routers. Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
03a31269ad
commit
49005f1a39
5 changed files with 39 additions and 68 deletions
|
|
@ -14,7 +14,7 @@ from typing import Any
|
|||
from fastapi import FastAPI
|
||||
|
||||
from llama_stack.core.resolver import api_protocol_map
|
||||
from llama_stack.core.server.fastapi_router_registry import build_router, has_router
|
||||
from llama_stack.core.server.fastapi_router_registry import build_router
|
||||
from llama_stack_api import Api
|
||||
|
||||
from .state import _protocol_methods_cache
|
||||
|
|
@ -77,14 +77,13 @@ def create_llama_stack_app() -> FastAPI:
|
|||
],
|
||||
)
|
||||
|
||||
# Include routers for APIs that have them (automatic discovery)
|
||||
# Include routers for APIs that have them
|
||||
protocols = api_protocol_map()
|
||||
for api in protocols.keys():
|
||||
if has_router(api):
|
||||
# For OpenAPI generation, we don't need a real implementation
|
||||
router = build_router(api, None)
|
||||
if router:
|
||||
app.include_router(router)
|
||||
# For OpenAPI generation, we don't need a real implementation
|
||||
router = build_router(api, None)
|
||||
if router:
|
||||
app.include_router(router)
|
||||
|
||||
# Get all API routes (for legacy webmethod-based routes)
|
||||
from llama_stack.core.server.routes import get_all_api_routes
|
||||
|
|
@ -96,7 +95,7 @@ def create_llama_stack_app() -> FastAPI:
|
|||
|
||||
for api, routes in api_routes.items():
|
||||
# Skip APIs that have routers - they're already included above
|
||||
if has_router(api):
|
||||
if build_router(api, None) is not None:
|
||||
continue
|
||||
|
||||
for route, webmethod in routes:
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from pydantic import BaseModel
|
|||
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.core.external import load_external_apis
|
||||
from llama_stack.core.server.fastapi_router_registry import build_router, has_router
|
||||
from llama_stack.core.server.fastapi_router_registry import build_router
|
||||
from llama_stack.core.server.routes import get_all_api_routes
|
||||
from llama_stack_api import (
|
||||
Api,
|
||||
|
|
@ -70,7 +70,7 @@ class DistributionInspectImpl(Inspect):
|
|||
# Process webmethod-based routes (legacy)
|
||||
for api, endpoints in all_endpoints.items():
|
||||
# Skip APIs that have routers - they'll be processed separately
|
||||
if has_router(api):
|
||||
if build_router(api, None) is not None:
|
||||
continue
|
||||
|
||||
provider_types = get_provider_types(api)
|
||||
|
|
@ -113,9 +113,6 @@ class DistributionInspectImpl(Inspect):
|
|||
|
||||
protocols = api_protocol_map(external_apis)
|
||||
for api in protocols.keys():
|
||||
if not has_router(api):
|
||||
continue
|
||||
|
||||
# For route inspection, we don't need a real implementation
|
||||
router = build_router(api, None)
|
||||
if not router:
|
||||
|
|
|
|||
|
|
@ -6,11 +6,10 @@
|
|||
|
||||
"""Router utilities for FastAPI routers.
|
||||
|
||||
This module provides utilities to discover and create FastAPI routers from API packages.
|
||||
Routers are automatically discovered by checking for fastapi_routes modules in each API package.
|
||||
This module provides utilities to create FastAPI routers from API packages.
|
||||
APIs with routers are explicitly listed here.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
|
@ -18,46 +17,30 @@ from fastapi import APIRouter
|
|||
if TYPE_CHECKING:
|
||||
from llama_stack_api.datatypes import Api
|
||||
|
||||
# Router factories for APIs that have FastAPI routers
|
||||
# Add new APIs here as they are migrated to the router system
|
||||
from llama_stack_api.batches.fastapi_routes import create_router as create_batches_router
|
||||
|
||||
def has_router(api: "Api") -> bool:
|
||||
"""Check if an API has a router factory in its fastapi_routes module.
|
||||
|
||||
Args:
|
||||
api: The API enum value
|
||||
|
||||
Returns:
|
||||
True if the API has a fastapi_routes module with a create_router function
|
||||
"""
|
||||
try:
|
||||
routes_module = importlib.import_module(f"llama_stack_api.{api.value}.fastapi_routes")
|
||||
return hasattr(routes_module, "create_router")
|
||||
except (ImportError, AttributeError):
|
||||
return False
|
||||
_ROUTER_FACTORIES: dict[str, APIRouter] = {
|
||||
"batches": create_batches_router,
|
||||
}
|
||||
|
||||
|
||||
def build_router(api: "Api", impl: Any) -> APIRouter | None:
|
||||
"""Build a router for an API by combining its router factory with the implementation.
|
||||
|
||||
This function discovers the router factory from the API package's routes module
|
||||
and calls it with the implementation to create the final router instance.
|
||||
|
||||
Args:
|
||||
api: The API enum value
|
||||
impl: The implementation instance for the API
|
||||
|
||||
Returns:
|
||||
APIRouter if the API has a fastapi_routes module with create_router, None otherwise
|
||||
APIRouter if the API has a router factory, None otherwise
|
||||
"""
|
||||
try:
|
||||
routes_module = importlib.import_module(f"llama_stack_api.{api.value}.fastapi_routes")
|
||||
if hasattr(routes_module, "create_router"):
|
||||
router_factory = routes_module.create_router
|
||||
# cast is safe here: mypy can't verify the return type statically because
|
||||
# we're dynamically importing the module. However, all router factories in
|
||||
# API packages are required to return APIRouter. If a router factory returns the wrong
|
||||
# type, it will fail at runtime when app.include_router(router) is called
|
||||
return cast(APIRouter, router_factory(impl))
|
||||
except (ImportError, AttributeError):
|
||||
pass
|
||||
router_factory = _ROUTER_FACTORIES.get(api.value)
|
||||
if router_factory is None:
|
||||
return None
|
||||
|
||||
return None
|
||||
# cast is safe here: all router factories in API packages are required to return APIRouter.
|
||||
# If a router factory returns the wrong type, it will fail at runtime when
|
||||
# app.include_router(router) is called
|
||||
return cast(APIRouter, router_factory(impl))
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ from llama_stack.core.request_headers import (
|
|||
request_provider_data_context,
|
||||
user_from_scope,
|
||||
)
|
||||
from llama_stack.core.server.fastapi_router_registry import build_router, has_router
|
||||
from llama_stack.core.server.fastapi_router_registry import build_router
|
||||
from llama_stack.core.server.routes import get_all_api_routes
|
||||
from llama_stack.core.stack import (
|
||||
Stack,
|
||||
|
|
@ -469,13 +469,12 @@ def create_app() -> StackApp:
|
|||
api = Api(api_str)
|
||||
|
||||
# Try to discover and use a router factory from the API package
|
||||
if has_router(api):
|
||||
impl = impls[api]
|
||||
router = build_router(api, impl)
|
||||
if router:
|
||||
app.include_router(router)
|
||||
logger.debug(f"Registered router for {api} API")
|
||||
continue
|
||||
impl = impls[api]
|
||||
router = build_router(api, impl)
|
||||
if router:
|
||||
app.include_router(router)
|
||||
logger.debug(f"Registered router for {api} API")
|
||||
continue
|
||||
|
||||
# Fall back to old webmethod-based route discovery until the migration is complete
|
||||
impl = impls[api]
|
||||
|
|
|
|||
|
|
@ -6,11 +6,10 @@
|
|||
from aiohttp import hdrs
|
||||
|
||||
from llama_stack.core.external import ExternalApiSpec
|
||||
from llama_stack.core.server.fastapi_router_registry import has_router
|
||||
from llama_stack.core.server.fastapi_router_registry import _ROUTER_FACTORIES
|
||||
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
|
||||
from llama_stack.core.telemetry.tracing import end_trace, start_trace
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack_api.datatypes import Api
|
||||
from llama_stack_api.version import (
|
||||
LLAMA_STACK_API_V1,
|
||||
LLAMA_STACK_API_V1ALPHA,
|
||||
|
|
@ -35,20 +34,14 @@ class TracingMiddleware:
|
|||
"""Check if a path belongs to a router-based API.
|
||||
|
||||
Router-based APIs use FastAPI routers instead of the old webmethod system.
|
||||
We need to check if the path matches any router-based API prefix.
|
||||
Paths must start with a valid API level (v1, v1alpha, v1beta) followed by an API name.
|
||||
"""
|
||||
# Extract API name from path (e.g., /v1/batches -> batches)
|
||||
# Paths must start with a valid API level: /v1/{api_name} or /v1alpha/{api_name} or /v1beta/{api_name}
|
||||
parts = path.strip("/").split("/")
|
||||
if len(parts) >= 2 and parts[0] in VALID_API_LEVELS:
|
||||
api_name = parts[1]
|
||||
try:
|
||||
api = Api(api_name)
|
||||
return has_router(api)
|
||||
except (ValueError, KeyError):
|
||||
# Not a known API or not router-based
|
||||
return False
|
||||
return False
|
||||
if len(parts) < 2 or parts[0] not in VALID_API_LEVELS:
|
||||
return False
|
||||
|
||||
# Check directly if the API name is in the router factories list
|
||||
return parts[1] in _ROUTER_FACTORIES
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope.get("type") == "lifespan":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue