mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-21 01:42:27 +00:00
feat: convert Datasets API to use FastAPI router (#4359)
# What does this PR do? Convert the Datasets API from webmethod decorators to FastAPI router pattern. Fixes: https://github.com/llamastack/llama-stack/issues/4344 ## Test Plan CI Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
56f946f3f5
commit
700663028f
13 changed files with 716 additions and 335 deletions
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.datatypes import (
|
||||
DatasetWithOwner,
|
||||
|
|
@ -14,15 +13,18 @@ from llama_stack.log import get_logger
|
|||
from llama_stack_api import (
|
||||
Dataset,
|
||||
DatasetNotFoundError,
|
||||
DatasetPurpose,
|
||||
Datasets,
|
||||
DatasetType,
|
||||
DataSource,
|
||||
ListDatasetsResponse,
|
||||
ResourceType,
|
||||
RowsDataSource,
|
||||
URIDataSource,
|
||||
)
|
||||
from llama_stack_api.datasets.api import (
|
||||
Datasets,
|
||||
GetDatasetRequest,
|
||||
RegisterDatasetRequest,
|
||||
UnregisterDatasetRequest,
|
||||
)
|
||||
|
||||
from .common import CommonRoutingTableImpl
|
||||
|
||||
|
|
@ -33,19 +35,17 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
|||
async def list_datasets(self) -> ListDatasetsResponse:
|
||||
return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
|
||||
|
||||
async def get_dataset(self, dataset_id: str) -> Dataset:
|
||||
dataset = await self.get_object_by_identifier("dataset", dataset_id)
|
||||
async def get_dataset(self, request: GetDatasetRequest) -> Dataset:
|
||||
dataset = await self.get_object_by_identifier("dataset", request.dataset_id)
|
||||
if dataset is None:
|
||||
raise DatasetNotFoundError(dataset_id)
|
||||
raise DatasetNotFoundError(request.dataset_id)
|
||||
return dataset
|
||||
|
||||
async def register_dataset(
|
||||
self,
|
||||
purpose: DatasetPurpose,
|
||||
source: DataSource,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
dataset_id: str | None = None,
|
||||
) -> Dataset:
|
||||
async def register_dataset(self, request: RegisterDatasetRequest) -> Dataset:
|
||||
purpose = request.purpose
|
||||
source = request.source
|
||||
metadata = request.metadata
|
||||
dataset_id = request.dataset_id
|
||||
if isinstance(source, dict):
|
||||
if source["type"] == "uri":
|
||||
source = URIDataSource.parse_obj(source)
|
||||
|
|
@ -86,6 +86,6 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
|||
await self.register_object(dataset)
|
||||
return dataset
|
||||
|
||||
async def unregister_dataset(self, dataset_id: str) -> None:
|
||||
dataset = await self.get_dataset(dataset_id)
|
||||
async def unregister_dataset(self, request: UnregisterDatasetRequest) -> None:
|
||||
dataset = await self.get_dataset(GetDatasetRequest(dataset_id=request.dataset_id))
|
||||
await self.unregister_object(dataset)
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from fastapi import APIRouter
|
|||
from fastapi.routing import APIRoute
|
||||
from starlette.routing import Route
|
||||
|
||||
from llama_stack_api import batches, benchmarks
|
||||
from llama_stack_api import batches, benchmarks, datasets
|
||||
|
||||
# Router factories for APIs that have FastAPI routers
|
||||
# Add new APIs here as they are migrated to the router system
|
||||
|
|
@ -26,6 +26,7 @@ from llama_stack_api.datatypes import Api
|
|||
_ROUTER_FACTORIES: dict[str, Callable[[Any], APIRouter]] = {
|
||||
"batches": batches.fastapi_routes.create_router,
|
||||
"benchmarks": benchmarks.fastapi_routes.create_router,
|
||||
"datasets": datasets.fastapi_routes.create_router,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue