mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 01:48:05 +00:00
chore: removed impl_getter from router function
Refactored the router to accept the implementation directly instead of using the impl_getter pattern. The caller already knows which API it's building a router for.for Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
8a21d8debe
commit
95e9455335
4 changed files with 12 additions and 25 deletions
|
|
@ -14,6 +14,7 @@ from typing import Any
|
|||
from fastapi import FastAPI
|
||||
|
||||
from llama_stack.core.resolver import api_protocol_map
|
||||
from llama_stack.core.server.fastapi_router_registry import build_router, has_router
|
||||
from llama_stack_api import Api
|
||||
|
||||
from .state import _protocol_methods_cache
|
||||
|
|
@ -77,19 +78,11 @@ def create_llama_stack_app() -> FastAPI:
|
|||
)
|
||||
|
||||
# Include routers for APIs that have them (automatic discovery)
|
||||
from llama_stack.core.server.fastapi_router_registry import build_router, has_router
|
||||
|
||||
def dummy_impl_getter(api: Api) -> Any:
|
||||
"""Dummy implementation getter for OpenAPI generation."""
|
||||
return None
|
||||
|
||||
# Get all APIs that might have routers
|
||||
from llama_stack.core.resolver import api_protocol_map
|
||||
|
||||
protocols = api_protocol_map()
|
||||
for api in protocols.keys():
|
||||
if has_router(api):
|
||||
router = build_router(api, dummy_impl_getter)
|
||||
# For OpenAPI generation, we don't need a real implementation
|
||||
router = build_router(api, None)
|
||||
if router:
|
||||
app.include_router(router)
|
||||
|
||||
|
|
|
|||
|
|
@ -109,10 +109,6 @@ class DistributionInspectImpl(Inspect):
|
|||
return not route_deprecated
|
||||
|
||||
# Process router-based routes
|
||||
def dummy_impl_getter(api: Api) -> None:
|
||||
"""Dummy implementation getter for route inspection."""
|
||||
return None
|
||||
|
||||
from llama_stack.core.resolver import api_protocol_map
|
||||
|
||||
protocols = api_protocol_map(external_apis)
|
||||
|
|
@ -120,7 +116,8 @@ class DistributionInspectImpl(Inspect):
|
|||
if not has_router(api):
|
||||
continue
|
||||
|
||||
router = build_router(api, dummy_impl_getter)
|
||||
# For route inspection, we don't need a real implementation
|
||||
router = build_router(api, None)
|
||||
if not router:
|
||||
continue
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ Routers are automatically discovered by checking for routes modules in each API
|
|||
"""
|
||||
|
||||
import importlib
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
|
@ -36,15 +35,15 @@ def has_router(api: "Api") -> bool:
|
|||
return False
|
||||
|
||||
|
||||
def build_router(api: "Api", impl_getter: Callable[["Api"], Any]) -> APIRouter | None:
|
||||
def build_router(api: "Api", impl: Any) -> APIRouter | None:
|
||||
"""Build a router for an API by combining its router factory with the implementation.
|
||||
|
||||
This function discovers the router factory from the API package's routes module
|
||||
and calls it with the impl_getter to create the final router instance.
|
||||
and calls it with the implementation to create the final router instance.
|
||||
|
||||
Args:
|
||||
api: The API enum value
|
||||
impl_getter: Function that returns the implementation for a given API
|
||||
impl: The implementation instance for the API
|
||||
|
||||
Returns:
|
||||
APIRouter if the API has a routes module with create_router, None otherwise
|
||||
|
|
@ -53,7 +52,7 @@ def build_router(api: "Api", impl_getter: Callable[["Api"], Any]) -> APIRouter |
|
|||
routes_module = importlib.import_module(f"llama_stack_api.{api.value}.fastapi_routes")
|
||||
if hasattr(routes_module, "create_router"):
|
||||
router_factory = routes_module.create_router
|
||||
return router_factory(impl_getter)
|
||||
return router_factory(impl)
|
||||
except (ImportError, AttributeError):
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ FastAPI route decorators. The router is defined in the API package to keep
|
|||
all API-related code together.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Path, Query
|
||||
|
|
@ -23,16 +22,15 @@ from llama_stack_api.batches.models import (
|
|||
ListBatchesRequest,
|
||||
RetrieveBatchRequest,
|
||||
)
|
||||
from llama_stack_api.datatypes import Api
|
||||
from llama_stack_api.router_utils import standard_responses
|
||||
from llama_stack_api.version import LLAMA_STACK_API_V1
|
||||
|
||||
|
||||
def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter:
|
||||
def create_router(impl: Batches) -> APIRouter:
|
||||
"""Create a FastAPI router for the Batches API.
|
||||
|
||||
Args:
|
||||
impl_getter: Function that returns the Batches implementation for the batches API
|
||||
impl: The Batches implementation instance
|
||||
|
||||
Returns:
|
||||
APIRouter configured for the Batches API
|
||||
|
|
@ -45,7 +43,7 @@ def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter:
|
|||
|
||||
def get_batch_service() -> Batches:
|
||||
"""Dependency function to get the batch service implementation."""
|
||||
return impl_getter(Api.batches)
|
||||
return impl
|
||||
|
||||
@router.post(
|
||||
"/batches",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue