llama-stack-mirror/src/llama_stack/apis/benchmarks/routes.py
Sébastien Han 8df9340dd3
wiprouters
Signed-off-by: Sébastien Han <seb@redhat.com>
2025-11-04 18:09:11 +01:00

144 lines
4.4 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated
from fastapi import Body, Depends, Request
from fastapi import Path as FastAPIPath
from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.core.server.router_utils import standard_responses
from llama_stack.core.server.routers import APIRouter, register_router
from .benchmarks_service import BenchmarksService
from .models import Benchmark, ListBenchmarksResponse, RegisterBenchmarkRequest
def get_benchmarks_service(request: Request) -> BenchmarksService:
    """Resolve the Benchmarks implementation stored on the application state.

    Reads the ``impls`` mapping from ``request.app.state`` (treated as empty
    when the attribute is absent) and returns the entry keyed by
    ``Api.benchmarks``.

    Raises:
        ValueError: if no entry for ``Api.benchmarks`` is present.
    """
    app_state = request.app.state
    registry = getattr(app_state, "impls", {})
    if Api.benchmarks in registry:
        return registry[Api.benchmarks]
    raise ValueError("Benchmarks API implementation not found")
def _benchmarks_router(api_version: str) -> APIRouter:
    """Build a Benchmarks-tagged router mounted under ``/<api_version>``."""
    return APIRouter(
        prefix=f"/{api_version}",
        tags=["Benchmarks"],
        responses=standard_responses,
    )


# One router per API version; the route functions below are attached to both.
router = _benchmarks_router(LLAMA_STACK_API_V1)
router_v1alpha = _benchmarks_router(LLAMA_STACK_API_V1ALPHA)
@router.get(
    "/eval/benchmarks",
    response_model=ListBenchmarksResponse,
    summary="List all benchmarks",
    description="List all benchmarks",
    deprecated=True,
)
@router_v1alpha.get(
    "/eval/benchmarks",
    response_model=ListBenchmarksResponse,
    summary="List all benchmarks",
    description="List all benchmarks",
)
async def list_benchmarks(
    svc: BenchmarksService = Depends(get_benchmarks_service),
) -> ListBenchmarksResponse:
    """Return every benchmark known to the service.

    Registered on both routers; the ``router`` registration is flagged
    ``deprecated=True``. The work is delegated to ``svc.list_benchmarks()``.
    """
    listing = await svc.list_benchmarks()
    return listing
@router.get(
    "/eval/benchmarks/{benchmark_id}",
    response_model=Benchmark,
    summary="Get a benchmark by its ID",
    description="Get a benchmark by its ID",
    deprecated=True,
)
@router_v1alpha.get(
    # Single braces on purpose: this is a plain string, not an f-string, so
    # "{{benchmark_id}}" would leave literal braces in the route instead of
    # declaring the ``benchmark_id`` path parameter like the route above.
    "/eval/benchmarks/{benchmark_id}",
    response_model=Benchmark,
    summary="Get a benchmark by its ID",
    description="Get a benchmark by its ID",
)
async def get_benchmark(
    benchmark_id: Annotated[str, FastAPIPath(..., description="The ID of the benchmark to get")],
    svc: BenchmarksService = Depends(get_benchmarks_service),
) -> Benchmark:
    """Get a benchmark by its ID.

    Registered on both routers with the same path template; the ``router``
    registration is flagged ``deprecated=True``.

    Args:
        benchmark_id: Path parameter identifying the benchmark to fetch.
        svc: Benchmarks service injected via ``get_benchmarks_service``.

    Returns:
        The result of ``svc.get_benchmark(benchmark_id=...)``.
    """
    return await svc.get_benchmark(benchmark_id=benchmark_id)
@router.post(
    "/eval/benchmarks",
    response_model=None,
    status_code=204,
    summary="Register a benchmark",
    description="Register a benchmark",
    deprecated=True,
)
@router_v1alpha.post(
    "/eval/benchmarks",
    response_model=None,
    status_code=204,
    summary="Register a benchmark",
    description="Register a benchmark",
)
async def register_benchmark(
    body: RegisterBenchmarkRequest = Body(...),
    svc: BenchmarksService = Depends(get_benchmarks_service),
) -> None:
    """Register the benchmark described by the request body.

    Each field of ``body`` is forwarded as a keyword argument to
    ``svc.register_benchmark``. Registered on both routers with status code
    204; the ``router`` registration is flagged ``deprecated=True``.
    """
    # Unpack the request model into the keyword arguments the service expects.
    registration = {
        "benchmark_id": body.benchmark_id,
        "dataset_id": body.dataset_id,
        "scoring_functions": body.scoring_functions,
        "provider_benchmark_id": body.provider_benchmark_id,
        "provider_id": body.provider_id,
        "metadata": body.metadata,
    }
    return await svc.register_benchmark(**registration)
@router.delete(
    "/eval/benchmarks/{benchmark_id}",
    response_model=None,
    status_code=204,
    summary="Unregister a benchmark",
    description="Unregister a benchmark",
    deprecated=True,
)
@router_v1alpha.delete(
    # Single braces on purpose: this is a plain string, not an f-string, so
    # "{{benchmark_id}}" would leave literal braces in the route instead of
    # declaring the ``benchmark_id`` path parameter like the route above.
    "/eval/benchmarks/{benchmark_id}",
    response_model=None,
    status_code=204,
    summary="Unregister a benchmark",
    description="Unregister a benchmark",
)
async def unregister_benchmark(
    benchmark_id: Annotated[str, FastAPIPath(..., description="The ID of the benchmark to unregister")],
    svc: BenchmarksService = Depends(get_benchmarks_service),
) -> None:
    """Unregister a benchmark.

    Registered on both routers with the same path template and status code
    204; the ``router`` registration is flagged ``deprecated=True``.

    Args:
        benchmark_id: Path parameter identifying the benchmark to remove.
        svc: Benchmarks service injected via ``get_benchmarks_service``.
    """
    await svc.unregister_benchmark(benchmark_id=benchmark_id)
# Kept so the router registry system can still obtain a router via a factory.
def create_benchmarks_router(impl_getter) -> APIRouter:
    """Return a fresh router that includes both versioned Benchmarks routers.

    ``impl_getter`` is accepted to satisfy the factory signature but is not
    used here; the routes resolve their implementation through
    ``get_benchmarks_service`` instead.
    """
    combined = APIRouter()
    for versioned in (router, router_v1alpha):
        combined.include_router(versioned)
    return combined
# Import-time side effect: hand the factory to the router registry under
# Api.benchmarks. NOTE(review): presumably the server calls the factory later
# when assembling the app — confirm against register_router's implementation.
register_router(Api.benchmarks, create_benchmarks_router)