feat: convert Benchmarks API to use FastAPI router (#4309)

# What does this PR do?

Convert the Benchmarks API from @webmethod decorators to the FastAPI router
pattern, matching the structure of the Batches API.

One notable change is the update of stack.py to handle request models in
register_resources().

Closes: #4308 

## Test Plan

CI and `curl http://localhost:8321/v1/inspect/routes | jq '.data[] |
select(.route | contains("benchmark"))'`

---------

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-12-10 15:04:27 +01:00 committed by GitHub
parent 661985e240
commit ff375f1abb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 862 additions and 195 deletions

View file

@ -9,6 +9,7 @@ from importlib.metadata import version
from pydantic import BaseModel
from llama_stack.core.datatypes import StackConfig
from llama_stack.core.distribution import builtin_automatically_routed_apis
from llama_stack.core.external import load_external_apis
from llama_stack.core.server.fastapi_router_registry import (
_ROUTER_FACTORIES,
@ -65,6 +66,17 @@ class DistributionInspectImpl(Inspect):
def _get_provider_types(api: Api) -> list[str]:
if api.value in ["providers", "inspect"]:
return [] # These APIs don't have "real" providers they're internal to the stack
# For routing table APIs, look up providers from their router API
# (e.g., benchmarks -> eval, models -> inference, etc.)
auto_routed_apis = builtin_automatically_routed_apis()
for auto_routed in auto_routed_apis:
if auto_routed.routing_table_api == api:
# This is a routing table API, use its router API for providers
providers = config.providers.get(auto_routed.router_api.value, [])
return [p.provider_type for p in providers] if providers else []
# Regular API, look up providers directly
providers = config.providers.get(api.value, [])
return [p.provider_type for p in providers] if providers else []

View file

@ -10,6 +10,7 @@ import json
import logging # allow-direct-logging
import os
import sys
import typing
from enum import Enum
from io import BytesIO
from pathlib import Path
@ -490,6 +491,25 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
unwrapped_body_param = param
break
# Check for parameters with Depends() annotation (FastAPI router endpoints)
# These need special handling: construct the request model from body
depends_param = None
for param in params_list:
param_type = param.annotation
if get_origin(param_type) is typing.Annotated:
args = get_args(param_type)
if len(args) > 1:
# Check if any metadata is Depends
metadata = args[1:]
for item in metadata:
# Check if it's a Depends object (has dependency attribute or is a callable)
# Depends objects typically have a 'dependency' attribute or are callable functions
if hasattr(item, "dependency") or callable(item) or "Depends" in str(type(item)):
depends_param = param
break
if depends_param:
break
# Convert parameters to Pydantic models where needed
converted_body = {}
for param_name, param in sig.parameters.items():
@ -500,6 +520,27 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
else:
converted_body[param_name] = convert_to_pydantic(param.annotation, value)
# Handle Depends parameter: construct request model from body
if depends_param and depends_param.name not in converted_body:
param_type = depends_param.annotation
if get_origin(param_type) is typing.Annotated:
base_type = get_args(param_type)[0]
# Handle Union types (e.g., SomeRequestModel | None) - extract the non-None type
# In Python 3.10+, Union types created with | syntax are still typing.Union
origin = get_origin(base_type)
if origin is Union:
# Get the first non-None type from the Union
union_args = get_args(base_type)
base_type = next(
(t for t in union_args if t is not type(None) and t is not None),
union_args[0] if union_args else None,
)
# Only try to instantiate if it's a class (not a Union or other non-callable type)
if base_type is not None and inspect.isclass(base_type) and callable(base_type):
# Construct the request model from all body parameters
converted_body[depends_param.name] = base_type(**body)
# handle unwrapped body parameter after processing all named parameters
if unwrapped_body_param:
base_type = get_args(unwrapped_body_param.annotation)[0]

View file

@ -4,13 +4,20 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.core.datatypes import (
BenchmarkWithOwner,
)
from llama_stack.log import get_logger
from llama_stack_api import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack_api import (
Benchmark,
Benchmarks,
GetBenchmarkRequest,
ListBenchmarksRequest,
ListBenchmarksResponse,
RegisterBenchmarkRequest,
UnregisterBenchmarkRequest,
)
from .common import CommonRoutingTableImpl
@ -18,26 +25,21 @@ logger = get_logger(name=__name__, category="core::routing_tables")
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
async def list_benchmarks(self) -> ListBenchmarksResponse:
async def list_benchmarks(self, request: ListBenchmarksRequest) -> ListBenchmarksResponse:
return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))
async def get_benchmark(self, benchmark_id: str) -> Benchmark:
benchmark = await self.get_object_by_identifier("benchmark", benchmark_id)
async def get_benchmark(self, request: GetBenchmarkRequest) -> Benchmark:
benchmark = await self.get_object_by_identifier("benchmark", request.benchmark_id)
if benchmark is None:
raise ValueError(f"Benchmark '{benchmark_id}' not found")
raise ValueError(f"Benchmark '{request.benchmark_id}' not found")
return benchmark
async def register_benchmark(
self,
benchmark_id: str,
dataset_id: str,
scoring_functions: list[str],
metadata: dict[str, Any] | None = None,
provider_benchmark_id: str | None = None,
provider_id: str | None = None,
request: RegisterBenchmarkRequest,
) -> None:
if metadata is None:
metadata = {}
metadata = request.metadata if request.metadata is not None else {}
provider_id = request.provider_id
if provider_id is None:
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
@ -45,18 +47,20 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
provider_benchmark_id = request.provider_benchmark_id
if provider_benchmark_id is None:
provider_benchmark_id = benchmark_id
provider_benchmark_id = request.benchmark_id
benchmark = BenchmarkWithOwner(
identifier=benchmark_id,
dataset_id=dataset_id,
scoring_functions=scoring_functions,
identifier=request.benchmark_id,
dataset_id=request.dataset_id,
scoring_functions=request.scoring_functions,
metadata=metadata,
provider_id=provider_id,
provider_resource_id=provider_benchmark_id,
)
await self.register_object(benchmark)
async def unregister_benchmark(self, benchmark_id: str) -> None:
existing_benchmark = await self.get_benchmark(benchmark_id)
async def unregister_benchmark(self, request: UnregisterBenchmarkRequest) -> None:
get_request = GetBenchmarkRequest(benchmark_id=request.benchmark_id)
existing_benchmark = await self.get_benchmark(get_request)
await self.unregister_object(existing_benchmark)

View file

@ -17,7 +17,7 @@ from fastapi import APIRouter
from fastapi.routing import APIRoute
from starlette.routing import Route
from llama_stack_api import batches
from llama_stack_api import batches, benchmarks
# Router factories for APIs that have FastAPI routers
# Add new APIs here as they are migrated to the router system
@ -25,6 +25,7 @@ from llama_stack_api.datatypes import Api
_ROUTER_FACTORIES: dict[str, Callable[[Any], APIRouter]] = {
"batches": batches.fastapi_routes.create_router,
"benchmarks": benchmarks.fastapi_routes.create_router,
}

View file

@ -13,6 +13,11 @@ from aiohttp import hdrs
from starlette.routing import Route
from llama_stack.core.resolver import api_protocol_map
from llama_stack.core.server.fastapi_router_registry import (
_ROUTER_FACTORIES,
build_fastapi_router,
get_router_routes,
)
from llama_stack_api import Api, ExternalApiSpec, WebMethod
EndpointFunc = Callable[..., Any]
@ -85,7 +90,53 @@ def initialize_route_impls(impls, external_apis: dict[Api, ExternalApiSpec] | No
return f"^{pattern}$"
# Process routes from FastAPI routers
for api_name in _ROUTER_FACTORIES.keys():
api = Api(api_name)
if api not in impls:
continue
impl = impls[api]
router = build_fastapi_router(api, impl)
if router:
router_routes = get_router_routes(router)
for route in router_routes:
# Get the endpoint function from the route
# For FastAPI routes, the endpoint is the actual function
func = route.endpoint
if func is None:
continue
# Get the first (and typically only) method from the set, filtering out HEAD
available_methods = [m for m in (route.methods or []) if m != "HEAD"]
if not available_methods:
continue # Skip if only HEAD method is available
method = available_methods[0].lower()
if method not in route_impls:
route_impls[method] = {}
# Create a minimal WebMethod for router routes (needed for RouteMatch tuple)
# We don't have webmethod metadata for router routes, so create a minimal one
# that has the attributes used by the library client (descriptive_name for tracing)
#
# TODO: Long-term migration plan (once all APIs are migrated to FastAPI routers):
# - Extract summary from APIRoute: route.summary (available on FastAPI APIRoute objects)
# - Pass summary directly in RouteMatch instead of WebMethod
# - Remove this WebMethod() instantiation entirely
# - Update library_client.py to use the extracted summary instead of webmethod.descriptive_name
webmethod = WebMethod(descriptive_name=None)
route_impls[method][_convert_path_to_regex(route.path)] = (
func,
route.path,
webmethod,
)
# Process routes from legacy webmethod-based APIs
for api, api_routes in api_to_routes.items():
# Skip APIs that have routers (already processed above)
if api.value in _ROUTER_FACTORIES:
continue
if api not in impls:
continue
for route, webmethod in api_routes:

View file

@ -6,12 +6,14 @@
import asyncio
import importlib.resources
import inspect
import os
import re
import tempfile
from typing import Any
from typing import Any, get_type_hints
import yaml
from pydantic import BaseModel
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
from llama_stack.core.datatypes import Provider, SafetyConfig, StackConfig, VectorStoresConfig
@ -108,6 +110,81 @@ REGISTRY_REFRESH_TASK = None
TEST_RECORDING_CONTEXT = None
def is_request_model(t: Any) -> bool:
    """Check if a type is a request model (Pydantic BaseModel).

    Args:
        t: The type to check.

    Returns:
        True if *t* is a class deriving from ``BaseModel``; False otherwise.
    """
    if not inspect.isclass(t):
        return False
    return issubclass(t, BaseModel)
async def invoke_with_optional_request(method: Any) -> Any:
"""Invoke a method, automatically creating a request instance if needed.
For APIs that use request models, this will create an empty request object.
For backward compatibility, falls back to calling without arguments.
Uses get_type_hints() to resolve forward references (e.g., "ListBenchmarksRequest" -> actual class).
Handles methods with:
- No parameters: calls without arguments
- One or more request model parameters: creates empty instances for each
- Mixed parameters: creates request models, uses defaults for others
- Required non-request-model parameters without defaults: falls back to calling without arguments
Args:
method: The method to invoke
Returns:
The result of calling the method
"""
try:
hints = get_type_hints(method)
except Exception:
# Forward references can't be resolved, fall back to calling without request
return await method()
params = list(inspect.signature(method).parameters.values())
params = [p for p in params if p.name != "self"]
if not params:
return await method()
# Build arguments for the method call
args: dict[str, Any] = {}
can_call = True
for param in params:
param_type = hints.get(param.name)
# If it's a request model, try to create an empty instance
if param_type and is_request_model(param_type):
try:
args[param.name] = param_type()
except Exception:
# Request model requires arguments, can't create empty instance
can_call = False
break
# If it has a default value, we can skip it (will use default)
elif param.default != inspect.Parameter.empty:
continue
# Required parameter that's not a request model - can't provide it
else:
can_call = False
break
if can_call and args:
return await method(**args)
# Fall back to calling without arguments for backward compatibility
return await method()
async def register_resources(run_config: StackConfig, impls: dict[Api, Any]):
for rsrc, api, register_method, list_method in RESOURCES:
objects = getattr(run_config.registered_resources, rsrc)
@ -129,7 +206,7 @@ async def register_resources(run_config: StackConfig, impls: dict[Api, Any]):
await method(**{k: getattr(obj, k) for k in obj.model_dump().keys()})
method = getattr(impls[api], list_method)
response = await method()
response = await invoke_with_optional_request(method)
objects_to_process = response.data if hasattr(response, "data") else response