feat: convert Datasets API to use FastAPI router (#4359)

# What does this PR do?

Convert the Datasets API from webmethod decorators to FastAPI router
pattern.

Fixes: https://github.com/llamastack/llama-stack/issues/4344

## Test Plan
CI

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-12-15 20:23:04 +01:00 committed by GitHub
parent 56f946f3f5
commit 700663028f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 716 additions and 335 deletions

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import uuid
from typing import Any
from llama_stack.core.datatypes import (
DatasetWithOwner,
@ -14,15 +13,18 @@ from llama_stack.log import get_logger
from llama_stack_api import (
Dataset,
DatasetNotFoundError,
DatasetPurpose,
Datasets,
DatasetType,
DataSource,
ListDatasetsResponse,
ResourceType,
RowsDataSource,
URIDataSource,
)
from llama_stack_api.datasets.api import (
Datasets,
GetDatasetRequest,
RegisterDatasetRequest,
UnregisterDatasetRequest,
)
from .common import CommonRoutingTableImpl
@ -33,19 +35,17 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> ListDatasetsResponse:
return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
async def get_dataset(self, dataset_id: str) -> Dataset:
dataset = await self.get_object_by_identifier("dataset", dataset_id)
async def get_dataset(self, request: GetDatasetRequest) -> Dataset:
dataset = await self.get_object_by_identifier("dataset", request.dataset_id)
if dataset is None:
raise DatasetNotFoundError(dataset_id)
raise DatasetNotFoundError(request.dataset_id)
return dataset
async def register_dataset(
self,
purpose: DatasetPurpose,
source: DataSource,
metadata: dict[str, Any] | None = None,
dataset_id: str | None = None,
) -> Dataset:
async def register_dataset(self, request: RegisterDatasetRequest) -> Dataset:
purpose = request.purpose
source = request.source
metadata = request.metadata
dataset_id = request.dataset_id
if isinstance(source, dict):
if source["type"] == "uri":
source = URIDataSource.parse_obj(source)
@ -86,6 +86,6 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
await self.register_object(dataset)
return dataset
async def unregister_dataset(self, dataset_id: str) -> None:
dataset = await self.get_dataset(dataset_id)
async def unregister_dataset(self, request: UnregisterDatasetRequest) -> None:
dataset = await self.get_dataset(GetDatasetRequest(dataset_id=request.dataset_id))
await self.unregister_object(dataset)

View file

@ -17,7 +17,7 @@ from fastapi import APIRouter
from fastapi.routing import APIRoute
from starlette.routing import Route
from llama_stack_api import batches, benchmarks
from llama_stack_api import batches, benchmarks, datasets
# Router factories for APIs that have FastAPI routers
# Add new APIs here as they are migrated to the router system
@ -26,6 +26,7 @@ from llama_stack_api.datatypes import Api
_ROUTER_FACTORIES: dict[str, Callable[[Any], APIRouter]] = {
"batches": batches.fastapi_routes.create_router,
"benchmarks": benchmarks.fastapi_routes.create_router,
"datasets": datasets.fastapi_routes.create_router,
}