mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
chore: removed impl_getter from router function
Refactored the router to accept the implementation directly instead of using the impl_getter pattern. The caller already knows which API it's building a router for.for Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
8a21d8debe
commit
95e9455335
4 changed files with 12 additions and 25 deletions
|
|
@ -14,6 +14,7 @@ from typing import Any
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
from llama_stack.core.resolver import api_protocol_map
|
from llama_stack.core.resolver import api_protocol_map
|
||||||
|
from llama_stack.core.server.fastapi_router_registry import build_router, has_router
|
||||||
from llama_stack_api import Api
|
from llama_stack_api import Api
|
||||||
|
|
||||||
from .state import _protocol_methods_cache
|
from .state import _protocol_methods_cache
|
||||||
|
|
@ -77,19 +78,11 @@ def create_llama_stack_app() -> FastAPI:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Include routers for APIs that have them (automatic discovery)
|
# Include routers for APIs that have them (automatic discovery)
|
||||||
from llama_stack.core.server.fastapi_router_registry import build_router, has_router
|
|
||||||
|
|
||||||
def dummy_impl_getter(api: Api) -> Any:
|
|
||||||
"""Dummy implementation getter for OpenAPI generation."""
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Get all APIs that might have routers
|
|
||||||
from llama_stack.core.resolver import api_protocol_map
|
|
||||||
|
|
||||||
protocols = api_protocol_map()
|
protocols = api_protocol_map()
|
||||||
for api in protocols.keys():
|
for api in protocols.keys():
|
||||||
if has_router(api):
|
if has_router(api):
|
||||||
router = build_router(api, dummy_impl_getter)
|
# For OpenAPI generation, we don't need a real implementation
|
||||||
|
router = build_router(api, None)
|
||||||
if router:
|
if router:
|
||||||
app.include_router(router)
|
app.include_router(router)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -109,10 +109,6 @@ class DistributionInspectImpl(Inspect):
|
||||||
return not route_deprecated
|
return not route_deprecated
|
||||||
|
|
||||||
# Process router-based routes
|
# Process router-based routes
|
||||||
def dummy_impl_getter(api: Api) -> None:
|
|
||||||
"""Dummy implementation getter for route inspection."""
|
|
||||||
return None
|
|
||||||
|
|
||||||
from llama_stack.core.resolver import api_protocol_map
|
from llama_stack.core.resolver import api_protocol_map
|
||||||
|
|
||||||
protocols = api_protocol_map(external_apis)
|
protocols = api_protocol_map(external_apis)
|
||||||
|
|
@ -120,7 +116,8 @@ class DistributionInspectImpl(Inspect):
|
||||||
if not has_router(api):
|
if not has_router(api):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
router = build_router(api, dummy_impl_getter)
|
# For route inspection, we don't need a real implementation
|
||||||
|
router = build_router(api, None)
|
||||||
if not router:
|
if not router:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,6 @@ Routers are automatically discovered by checking for routes modules in each API
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
from collections.abc import Callable
|
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
|
@ -36,15 +35,15 @@ def has_router(api: "Api") -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def build_router(api: "Api", impl_getter: Callable[["Api"], Any]) -> APIRouter | None:
|
def build_router(api: "Api", impl: Any) -> APIRouter | None:
|
||||||
"""Build a router for an API by combining its router factory with the implementation.
|
"""Build a router for an API by combining its router factory with the implementation.
|
||||||
|
|
||||||
This function discovers the router factory from the API package's routes module
|
This function discovers the router factory from the API package's routes module
|
||||||
and calls it with the impl_getter to create the final router instance.
|
and calls it with the implementation to create the final router instance.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
api: The API enum value
|
api: The API enum value
|
||||||
impl_getter: Function that returns the implementation for a given API
|
impl: The implementation instance for the API
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
APIRouter if the API has a routes module with create_router, None otherwise
|
APIRouter if the API has a routes module with create_router, None otherwise
|
||||||
|
|
@ -53,7 +52,7 @@ def build_router(api: "Api", impl_getter: Callable[["Api"], Any]) -> APIRouter |
|
||||||
routes_module = importlib.import_module(f"llama_stack_api.{api.value}.fastapi_routes")
|
routes_module = importlib.import_module(f"llama_stack_api.{api.value}.fastapi_routes")
|
||||||
if hasattr(routes_module, "create_router"):
|
if hasattr(routes_module, "create_router"):
|
||||||
router_factory = routes_module.create_router
|
router_factory = routes_module.create_router
|
||||||
return router_factory(impl_getter)
|
return router_factory(impl)
|
||||||
except (ImportError, AttributeError):
|
except (ImportError, AttributeError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,6 @@ FastAPI route decorators. The router is defined in the API package to keep
|
||||||
all API-related code together.
|
all API-related code together.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from collections.abc import Callable
|
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from fastapi import APIRouter, Body, Depends, Path, Query
|
from fastapi import APIRouter, Body, Depends, Path, Query
|
||||||
|
|
@ -23,16 +22,15 @@ from llama_stack_api.batches.models import (
|
||||||
ListBatchesRequest,
|
ListBatchesRequest,
|
||||||
RetrieveBatchRequest,
|
RetrieveBatchRequest,
|
||||||
)
|
)
|
||||||
from llama_stack_api.datatypes import Api
|
|
||||||
from llama_stack_api.router_utils import standard_responses
|
from llama_stack_api.router_utils import standard_responses
|
||||||
from llama_stack_api.version import LLAMA_STACK_API_V1
|
from llama_stack_api.version import LLAMA_STACK_API_V1
|
||||||
|
|
||||||
|
|
||||||
def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter:
|
def create_router(impl: Batches) -> APIRouter:
|
||||||
"""Create a FastAPI router for the Batches API.
|
"""Create a FastAPI router for the Batches API.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
impl_getter: Function that returns the Batches implementation for the batches API
|
impl: The Batches implementation instance
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
APIRouter configured for the Batches API
|
APIRouter configured for the Batches API
|
||||||
|
|
@ -45,7 +43,7 @@ def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter:
|
||||||
|
|
||||||
def get_batch_service() -> Batches:
|
def get_batch_service() -> Batches:
|
||||||
"""Dependency function to get the batch service implementation."""
|
"""Dependency function to get the batch service implementation."""
|
||||||
return impl_getter(Api.batches)
|
return impl
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/batches",
|
"/batches",
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue