mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
fix: use hardcoded list and dictionary mapping for router registry
Replace dynamic import-based router discovery with an explicit hardcoded list of APIs that have routers. Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
03a31269ad
commit
49005f1a39
5 changed files with 39 additions and 68 deletions
|
|
@ -14,7 +14,7 @@ from typing import Any
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
from llama_stack.core.resolver import api_protocol_map
|
from llama_stack.core.resolver import api_protocol_map
|
||||||
from llama_stack.core.server.fastapi_router_registry import build_router, has_router
|
from llama_stack.core.server.fastapi_router_registry import build_router
|
||||||
from llama_stack_api import Api
|
from llama_stack_api import Api
|
||||||
|
|
||||||
from .state import _protocol_methods_cache
|
from .state import _protocol_methods_cache
|
||||||
|
|
@ -77,10 +77,9 @@ def create_llama_stack_app() -> FastAPI:
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Include routers for APIs that have them (automatic discovery)
|
# Include routers for APIs that have them
|
||||||
protocols = api_protocol_map()
|
protocols = api_protocol_map()
|
||||||
for api in protocols.keys():
|
for api in protocols.keys():
|
||||||
if has_router(api):
|
|
||||||
# For OpenAPI generation, we don't need a real implementation
|
# For OpenAPI generation, we don't need a real implementation
|
||||||
router = build_router(api, None)
|
router = build_router(api, None)
|
||||||
if router:
|
if router:
|
||||||
|
|
@ -96,7 +95,7 @@ def create_llama_stack_app() -> FastAPI:
|
||||||
|
|
||||||
for api, routes in api_routes.items():
|
for api, routes in api_routes.items():
|
||||||
# Skip APIs that have routers - they're already included above
|
# Skip APIs that have routers - they're already included above
|
||||||
if has_router(api):
|
if build_router(api, None) is not None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for route, webmethod in routes:
|
for route, webmethod in routes:
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_stack.core.datatypes import StackRunConfig
|
from llama_stack.core.datatypes import StackRunConfig
|
||||||
from llama_stack.core.external import load_external_apis
|
from llama_stack.core.external import load_external_apis
|
||||||
from llama_stack.core.server.fastapi_router_registry import build_router, has_router
|
from llama_stack.core.server.fastapi_router_registry import build_router
|
||||||
from llama_stack.core.server.routes import get_all_api_routes
|
from llama_stack.core.server.routes import get_all_api_routes
|
||||||
from llama_stack_api import (
|
from llama_stack_api import (
|
||||||
Api,
|
Api,
|
||||||
|
|
@ -70,7 +70,7 @@ class DistributionInspectImpl(Inspect):
|
||||||
# Process webmethod-based routes (legacy)
|
# Process webmethod-based routes (legacy)
|
||||||
for api, endpoints in all_endpoints.items():
|
for api, endpoints in all_endpoints.items():
|
||||||
# Skip APIs that have routers - they'll be processed separately
|
# Skip APIs that have routers - they'll be processed separately
|
||||||
if has_router(api):
|
if build_router(api, None) is not None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
provider_types = get_provider_types(api)
|
provider_types = get_provider_types(api)
|
||||||
|
|
@ -113,9 +113,6 @@ class DistributionInspectImpl(Inspect):
|
||||||
|
|
||||||
protocols = api_protocol_map(external_apis)
|
protocols = api_protocol_map(external_apis)
|
||||||
for api in protocols.keys():
|
for api in protocols.keys():
|
||||||
if not has_router(api):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# For route inspection, we don't need a real implementation
|
# For route inspection, we don't need a real implementation
|
||||||
router = build_router(api, None)
|
router = build_router(api, None)
|
||||||
if not router:
|
if not router:
|
||||||
|
|
|
||||||
|
|
@ -6,11 +6,10 @@
|
||||||
|
|
||||||
"""Router utilities for FastAPI routers.
|
"""Router utilities for FastAPI routers.
|
||||||
|
|
||||||
This module provides utilities to discover and create FastAPI routers from API packages.
|
This module provides utilities to create FastAPI routers from API packages.
|
||||||
Routers are automatically discovered by checking for fastapi_routes modules in each API package.
|
APIs with routers are explicitly listed here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import importlib
|
|
||||||
from typing import TYPE_CHECKING, Any, cast
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
|
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
|
@ -18,46 +17,30 @@ from fastapi import APIRouter
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from llama_stack_api.datatypes import Api
|
from llama_stack_api.datatypes import Api
|
||||||
|
|
||||||
|
# Router factories for APIs that have FastAPI routers
|
||||||
|
# Add new APIs here as they are migrated to the router system
|
||||||
|
from llama_stack_api.batches.fastapi_routes import create_router as create_batches_router
|
||||||
|
|
||||||
def has_router(api: "Api") -> bool:
|
_ROUTER_FACTORIES: dict[str, APIRouter] = {
|
||||||
"""Check if an API has a router factory in its fastapi_routes module.
|
"batches": create_batches_router,
|
||||||
|
}
|
||||||
Args:
|
|
||||||
api: The API enum value
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the API has a fastapi_routes module with a create_router function
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
routes_module = importlib.import_module(f"llama_stack_api.{api.value}.fastapi_routes")
|
|
||||||
return hasattr(routes_module, "create_router")
|
|
||||||
except (ImportError, AttributeError):
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def build_router(api: "Api", impl: Any) -> APIRouter | None:
|
def build_router(api: "Api", impl: Any) -> APIRouter | None:
|
||||||
"""Build a router for an API by combining its router factory with the implementation.
|
"""Build a router for an API by combining its router factory with the implementation.
|
||||||
|
|
||||||
This function discovers the router factory from the API package's routes module
|
|
||||||
and calls it with the implementation to create the final router instance.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
api: The API enum value
|
api: The API enum value
|
||||||
impl: The implementation instance for the API
|
impl: The implementation instance for the API
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
APIRouter if the API has a fastapi_routes module with create_router, None otherwise
|
APIRouter if the API has a router factory, None otherwise
|
||||||
"""
|
"""
|
||||||
try:
|
router_factory = _ROUTER_FACTORIES.get(api.value)
|
||||||
routes_module = importlib.import_module(f"llama_stack_api.{api.value}.fastapi_routes")
|
if router_factory is None:
|
||||||
if hasattr(routes_module, "create_router"):
|
|
||||||
router_factory = routes_module.create_router
|
|
||||||
# cast is safe here: mypy can't verify the return type statically because
|
|
||||||
# we're dynamically importing the module. However, all router factories in
|
|
||||||
# API packages are required to return APIRouter. If a router factory returns the wrong
|
|
||||||
# type, it will fail at runtime when app.include_router(router) is called
|
|
||||||
return cast(APIRouter, router_factory(impl))
|
|
||||||
except (ImportError, AttributeError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# cast is safe here: all router factories in API packages are required to return APIRouter.
|
||||||
|
# If a router factory returns the wrong type, it will fail at runtime when
|
||||||
|
# app.include_router(router) is called
|
||||||
|
return cast(APIRouter, router_factory(impl))
|
||||||
|
|
|
||||||
|
|
@ -44,7 +44,7 @@ from llama_stack.core.request_headers import (
|
||||||
request_provider_data_context,
|
request_provider_data_context,
|
||||||
user_from_scope,
|
user_from_scope,
|
||||||
)
|
)
|
||||||
from llama_stack.core.server.fastapi_router_registry import build_router, has_router
|
from llama_stack.core.server.fastapi_router_registry import build_router
|
||||||
from llama_stack.core.server.routes import get_all_api_routes
|
from llama_stack.core.server.routes import get_all_api_routes
|
||||||
from llama_stack.core.stack import (
|
from llama_stack.core.stack import (
|
||||||
Stack,
|
Stack,
|
||||||
|
|
@ -469,7 +469,6 @@ def create_app() -> StackApp:
|
||||||
api = Api(api_str)
|
api = Api(api_str)
|
||||||
|
|
||||||
# Try to discover and use a router factory from the API package
|
# Try to discover and use a router factory from the API package
|
||||||
if has_router(api):
|
|
||||||
impl = impls[api]
|
impl = impls[api]
|
||||||
router = build_router(api, impl)
|
router = build_router(api, impl)
|
||||||
if router:
|
if router:
|
||||||
|
|
|
||||||
|
|
@ -6,11 +6,10 @@
|
||||||
from aiohttp import hdrs
|
from aiohttp import hdrs
|
||||||
|
|
||||||
from llama_stack.core.external import ExternalApiSpec
|
from llama_stack.core.external import ExternalApiSpec
|
||||||
from llama_stack.core.server.fastapi_router_registry import has_router
|
from llama_stack.core.server.fastapi_router_registry import _ROUTER_FACTORIES
|
||||||
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
|
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
|
||||||
from llama_stack.core.telemetry.tracing import end_trace, start_trace
|
from llama_stack.core.telemetry.tracing import end_trace, start_trace
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack_api.datatypes import Api
|
|
||||||
from llama_stack_api.version import (
|
from llama_stack_api.version import (
|
||||||
LLAMA_STACK_API_V1,
|
LLAMA_STACK_API_V1,
|
||||||
LLAMA_STACK_API_V1ALPHA,
|
LLAMA_STACK_API_V1ALPHA,
|
||||||
|
|
@ -35,21 +34,15 @@ class TracingMiddleware:
|
||||||
"""Check if a path belongs to a router-based API.
|
"""Check if a path belongs to a router-based API.
|
||||||
|
|
||||||
Router-based APIs use FastAPI routers instead of the old webmethod system.
|
Router-based APIs use FastAPI routers instead of the old webmethod system.
|
||||||
We need to check if the path matches any router-based API prefix.
|
Paths must start with a valid API level (v1, v1alpha, v1beta) followed by an API name.
|
||||||
"""
|
"""
|
||||||
# Extract API name from path (e.g., /v1/batches -> batches)
|
|
||||||
# Paths must start with a valid API level: /v1/{api_name} or /v1alpha/{api_name} or /v1beta/{api_name}
|
|
||||||
parts = path.strip("/").split("/")
|
parts = path.strip("/").split("/")
|
||||||
if len(parts) >= 2 and parts[0] in VALID_API_LEVELS:
|
if len(parts) < 2 or parts[0] not in VALID_API_LEVELS:
|
||||||
api_name = parts[1]
|
|
||||||
try:
|
|
||||||
api = Api(api_name)
|
|
||||||
return has_router(api)
|
|
||||||
except (ValueError, KeyError):
|
|
||||||
# Not a known API or not router-based
|
|
||||||
return False
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# Check directly if the API name is in the router factories list
|
||||||
|
return parts[1] in _ROUTER_FACTORIES
|
||||||
|
|
||||||
async def __call__(self, scope, receive, send):
|
async def __call__(self, scope, receive, send):
|
||||||
if scope.get("type") == "lifespan":
|
if scope.get("type") == "lifespan":
|
||||||
return await self.app(scope, receive, send)
|
return await self.app(scope, receive, send)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue