mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-21 10:58:41 +00:00
feat: convert Datasets API to use FastAPI router (#4359)
# What does this PR do? Convert the Datasets API from webmethod decorators to FastAPI router pattern. Fixes: https://github.com/llamastack/llama-stack/issues/4344 ## Test Plan CI Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
56f946f3f5
commit
700663028f
13 changed files with 716 additions and 335 deletions
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.datatypes import (
|
||||
DatasetWithOwner,
|
||||
|
|
@ -14,15 +13,18 @@ from llama_stack.log import get_logger
|
|||
from llama_stack_api import (
|
||||
Dataset,
|
||||
DatasetNotFoundError,
|
||||
DatasetPurpose,
|
||||
Datasets,
|
||||
DatasetType,
|
||||
DataSource,
|
||||
ListDatasetsResponse,
|
||||
ResourceType,
|
||||
RowsDataSource,
|
||||
URIDataSource,
|
||||
)
|
||||
from llama_stack_api.datasets.api import (
|
||||
Datasets,
|
||||
GetDatasetRequest,
|
||||
RegisterDatasetRequest,
|
||||
UnregisterDatasetRequest,
|
||||
)
|
||||
|
||||
from .common import CommonRoutingTableImpl
|
||||
|
||||
|
|
@ -33,19 +35,17 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
|||
async def list_datasets(self) -> ListDatasetsResponse:
|
||||
return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
|
||||
|
||||
async def get_dataset(self, dataset_id: str) -> Dataset:
|
||||
dataset = await self.get_object_by_identifier("dataset", dataset_id)
|
||||
async def get_dataset(self, request: GetDatasetRequest) -> Dataset:
|
||||
dataset = await self.get_object_by_identifier("dataset", request.dataset_id)
|
||||
if dataset is None:
|
||||
raise DatasetNotFoundError(dataset_id)
|
||||
raise DatasetNotFoundError(request.dataset_id)
|
||||
return dataset
|
||||
|
||||
async def register_dataset(
|
||||
self,
|
||||
purpose: DatasetPurpose,
|
||||
source: DataSource,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
dataset_id: str | None = None,
|
||||
) -> Dataset:
|
||||
async def register_dataset(self, request: RegisterDatasetRequest) -> Dataset:
|
||||
purpose = request.purpose
|
||||
source = request.source
|
||||
metadata = request.metadata
|
||||
dataset_id = request.dataset_id
|
||||
if isinstance(source, dict):
|
||||
if source["type"] == "uri":
|
||||
source = URIDataSource.parse_obj(source)
|
||||
|
|
@ -86,6 +86,6 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
|||
await self.register_object(dataset)
|
||||
return dataset
|
||||
|
||||
async def unregister_dataset(self, dataset_id: str) -> None:
|
||||
dataset = await self.get_dataset(dataset_id)
|
||||
async def unregister_dataset(self, request: UnregisterDatasetRequest) -> None:
|
||||
dataset = await self.get_dataset(GetDatasetRequest(dataset_id=request.dataset_id))
|
||||
await self.unregister_object(dataset)
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from fastapi import APIRouter
|
|||
from fastapi.routing import APIRoute
|
||||
from starlette.routing import Route
|
||||
|
||||
from llama_stack_api import batches, benchmarks
|
||||
from llama_stack_api import batches, benchmarks, datasets
|
||||
|
||||
# Router factories for APIs that have FastAPI routers
|
||||
# Add new APIs here as they are migrated to the router system
|
||||
|
|
@ -26,6 +26,7 @@ from llama_stack_api.datatypes import Api
|
|||
_ROUTER_FACTORIES: dict[str, Callable[[Any], APIRouter]] = {
|
||||
"batches": batches.fastapi_routes.create_router,
|
||||
"benchmarks": benchmarks.fastapi_routes.create_router,
|
||||
"datasets": datasets.fastapi_routes.create_router,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,248 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum, StrEnum
|
||||
from typing import Annotated, Any, Literal, Protocol
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack_api.resource import Resource, ResourceType
|
||||
from llama_stack_api.schema_utils import json_schema_type, register_schema, webmethod
|
||||
from llama_stack_api.version import LLAMA_STACK_API_V1BETA
|
||||
|
||||
|
||||
class DatasetPurpose(StrEnum):
|
||||
"""
|
||||
Purpose of the dataset. Each purpose has a required input data schema.
|
||||
|
||||
:cvar post-training/messages: The dataset contains messages used for post-training.
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, world!"},
|
||||
{"role": "assistant", "content": "Hello, world!"},
|
||||
]
|
||||
}
|
||||
:cvar eval/question-answer: The dataset contains a question column and an answer column.
|
||||
{
|
||||
"question": "What is the capital of France?",
|
||||
"answer": "Paris"
|
||||
}
|
||||
:cvar eval/messages-answer: The dataset contains a messages column with list of messages and an answer column.
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, my name is John Doe."},
|
||||
{"role": "assistant", "content": "Hello, John Doe. How can I help you today?"},
|
||||
{"role": "user", "content": "What's my name?"},
|
||||
],
|
||||
"answer": "John Doe"
|
||||
}
|
||||
"""
|
||||
|
||||
post_training_messages = "post-training/messages"
|
||||
eval_question_answer = "eval/question-answer"
|
||||
eval_messages_answer = "eval/messages-answer"
|
||||
|
||||
# TODO: add more schemas here
|
||||
|
||||
|
||||
class DatasetType(Enum):
|
||||
"""
|
||||
Type of the dataset source.
|
||||
:cvar uri: The dataset can be obtained from a URI.
|
||||
:cvar rows: The dataset is stored in rows.
|
||||
"""
|
||||
|
||||
uri = "uri"
|
||||
rows = "rows"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class URIDataSource(BaseModel):
|
||||
"""A dataset that can be obtained from a URI.
|
||||
:param uri: The dataset can be obtained from a URI. E.g.
|
||||
- "https://mywebsite.com/mydata.jsonl"
|
||||
- "lsfs://mydata.jsonl"
|
||||
- "data:csv;base64,{base64_content}"
|
||||
"""
|
||||
|
||||
type: Literal["uri"] = "uri"
|
||||
uri: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RowsDataSource(BaseModel):
|
||||
"""A dataset stored in rows.
|
||||
:param rows: The dataset is stored in rows. E.g.
|
||||
- [
|
||||
{"messages": [{"role": "user", "content": "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}]}
|
||||
]
|
||||
"""
|
||||
|
||||
type: Literal["rows"] = "rows"
|
||||
rows: list[dict[str, Any]]
|
||||
|
||||
|
||||
DataSource = Annotated[
|
||||
URIDataSource | RowsDataSource,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(DataSource, name="DataSource")
|
||||
|
||||
|
||||
class CommonDatasetFields(BaseModel):
|
||||
"""
|
||||
Common fields for a dataset.
|
||||
|
||||
:param purpose: Purpose of the dataset indicating its intended use
|
||||
:param source: Data source configuration for the dataset
|
||||
:param metadata: Additional metadata for the dataset
|
||||
"""
|
||||
|
||||
purpose: DatasetPurpose
|
||||
source: DataSource
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Any additional metadata for this dataset",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Dataset(CommonDatasetFields, Resource):
|
||||
"""Dataset resource for storing and accessing training or evaluation data.
|
||||
|
||||
:param type: Type of resource, always 'dataset' for datasets
|
||||
"""
|
||||
|
||||
type: Literal[ResourceType.dataset] = ResourceType.dataset
|
||||
|
||||
@property
|
||||
def dataset_id(self) -> str:
|
||||
return self.identifier
|
||||
|
||||
@property
|
||||
def provider_dataset_id(self) -> str | None:
|
||||
return self.provider_resource_id
|
||||
|
||||
|
||||
class DatasetInput(CommonDatasetFields, BaseModel):
|
||||
"""Input parameters for dataset operations.
|
||||
|
||||
:param dataset_id: Unique identifier for the dataset
|
||||
"""
|
||||
|
||||
dataset_id: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListDatasetsResponse(BaseModel):
|
||||
"""Response from listing datasets.
|
||||
|
||||
:param data: List of datasets
|
||||
"""
|
||||
|
||||
data: list[Dataset]
|
||||
|
||||
|
||||
class Datasets(Protocol):
|
||||
@webmethod(route="/datasets", method="POST", level=LLAMA_STACK_API_V1BETA, deprecated=True)
|
||||
async def register_dataset(
|
||||
self,
|
||||
purpose: DatasetPurpose,
|
||||
source: DataSource,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
dataset_id: str | None = None,
|
||||
) -> Dataset:
|
||||
"""
|
||||
Register a new dataset.
|
||||
|
||||
:param purpose: The purpose of the dataset.
|
||||
One of:
|
||||
- "post-training/messages": The dataset contains a messages column with list of messages for post-training.
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, world!"},
|
||||
{"role": "assistant", "content": "Hello, world!"},
|
||||
]
|
||||
}
|
||||
- "eval/question-answer": The dataset contains a question column and an answer column for evaluation.
|
||||
{
|
||||
"question": "What is the capital of France?",
|
||||
"answer": "Paris"
|
||||
}
|
||||
- "eval/messages-answer": The dataset contains a messages column with list of messages and an answer column for evaluation.
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, my name is John Doe."},
|
||||
{"role": "assistant", "content": "Hello, John Doe. How can I help you today?"},
|
||||
{"role": "user", "content": "What's my name?"},
|
||||
],
|
||||
"answer": "John Doe"
|
||||
}
|
||||
:param source: The data source of the dataset. Ensure that the data source schema is compatible with the purpose of the dataset. Examples:
|
||||
- {
|
||||
"type": "uri",
|
||||
"uri": "https://mywebsite.com/mydata.jsonl"
|
||||
}
|
||||
- {
|
||||
"type": "uri",
|
||||
"uri": "lsfs://mydata.jsonl"
|
||||
}
|
||||
- {
|
||||
"type": "uri",
|
||||
"uri": "data:csv;base64,{base64_content}"
|
||||
}
|
||||
- {
|
||||
"type": "uri",
|
||||
"uri": "huggingface://llamastack/simpleqa?split=train"
|
||||
}
|
||||
- {
|
||||
"type": "rows",
|
||||
"rows": [
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, world!"},
|
||||
{"role": "assistant", "content": "Hello, world!"},
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
:param metadata: The metadata for the dataset.
|
||||
- E.g. {"description": "My dataset"}.
|
||||
:param dataset_id: The ID of the dataset. If not provided, an ID will be generated.
|
||||
:returns: A Dataset.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/datasets/{dataset_id:path}", method="GET", level=LLAMA_STACK_API_V1BETA)
|
||||
async def get_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
) -> Dataset:
|
||||
"""Get a dataset by its ID.
|
||||
|
||||
:param dataset_id: The ID of the dataset to get.
|
||||
:returns: A Dataset.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/datasets", method="GET", level=LLAMA_STACK_API_V1BETA)
|
||||
async def list_datasets(self) -> ListDatasetsResponse:
|
||||
"""List all datasets.
|
||||
|
||||
:returns: A ListDatasetsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE", level=LLAMA_STACK_API_V1BETA, deprecated=True)
|
||||
async def unregister_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
) -> None:
|
||||
"""Unregister a dataset by its ID.
|
||||
|
||||
:param dataset_id: The ID of the dataset to unregister.
|
||||
"""
|
||||
...
|
||||
61
src/llama_stack_api/datasets/__init__.py
Normal file
61
src/llama_stack_api/datasets/__init__.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Datasets API protocol and models.
|
||||
|
||||
This module contains the Datasets protocol definition.
|
||||
Pydantic models are defined in llama_stack_api.datasets.models.
|
||||
The FastAPI router is defined in llama_stack_api.datasets.fastapi_routes.
|
||||
"""
|
||||
|
||||
# Import fastapi_routes for router factory access
|
||||
from . import fastapi_routes
|
||||
|
||||
# Import new protocol for FastAPI router
|
||||
from .api import Datasets
|
||||
|
||||
# Import models for re-export
|
||||
from .models import (
|
||||
CommonDatasetFields,
|
||||
Dataset,
|
||||
DatasetPurpose,
|
||||
DatasetType,
|
||||
DataSource,
|
||||
GetDatasetRequest,
|
||||
ListDatasetsResponse,
|
||||
RegisterDatasetRequest,
|
||||
RowsDataSource,
|
||||
UnregisterDatasetRequest,
|
||||
URIDataSource,
|
||||
)
|
||||
|
||||
|
||||
# Define DatasetInput for backward compatibility
|
||||
class DatasetInput(CommonDatasetFields):
|
||||
"""Input parameters for dataset operations.
|
||||
|
||||
:param dataset_id: Unique identifier for the dataset
|
||||
"""
|
||||
|
||||
dataset_id: str
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Datasets",
|
||||
"Dataset",
|
||||
"CommonDatasetFields",
|
||||
"DatasetPurpose",
|
||||
"DataSource",
|
||||
"DatasetInput",
|
||||
"DatasetType",
|
||||
"RowsDataSource",
|
||||
"URIDataSource",
|
||||
"ListDatasetsResponse",
|
||||
"RegisterDatasetRequest",
|
||||
"GetDatasetRequest",
|
||||
"UnregisterDatasetRequest",
|
||||
"fastapi_routes",
|
||||
]
|
||||
35
src/llama_stack_api/datasets/api.py
Normal file
35
src/llama_stack_api/datasets/api.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Datasets API protocol definition.
|
||||
|
||||
This module contains the Datasets protocol definition.
|
||||
Pydantic models are defined in llama_stack_api.datasets.models.
|
||||
The FastAPI router is defined in llama_stack_api.datasets.fastapi_routes.
|
||||
"""
|
||||
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from .models import (
|
||||
Dataset,
|
||||
GetDatasetRequest,
|
||||
ListDatasetsResponse,
|
||||
RegisterDatasetRequest,
|
||||
UnregisterDatasetRequest,
|
||||
)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Datasets(Protocol):
|
||||
"""Protocol for dataset management operations."""
|
||||
|
||||
async def register_dataset(self, request: RegisterDatasetRequest) -> Dataset: ...
|
||||
|
||||
async def get_dataset(self, request: GetDatasetRequest) -> Dataset: ...
|
||||
|
||||
async def list_datasets(self) -> ListDatasetsResponse: ...
|
||||
|
||||
async def unregister_dataset(self, request: UnregisterDatasetRequest) -> None: ...
|
||||
104
src/llama_stack_api/datasets/fastapi_routes.py
Normal file
104
src/llama_stack_api/datasets/fastapi_routes.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""FastAPI router for the Datasets API.
|
||||
|
||||
This module defines the FastAPI router for the Datasets API using standard
|
||||
FastAPI route decorators.
|
||||
"""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Body, Depends
|
||||
|
||||
from llama_stack_api.router_utils import create_path_dependency, standard_responses
|
||||
from llama_stack_api.version import LLAMA_STACK_API_V1BETA
|
||||
|
||||
from .api import Datasets
|
||||
from .models import (
|
||||
Dataset,
|
||||
GetDatasetRequest,
|
||||
ListDatasetsResponse,
|
||||
RegisterDatasetRequest,
|
||||
UnregisterDatasetRequest,
|
||||
)
|
||||
|
||||
# Path parameter dependencies for single-field models
|
||||
get_dataset_request = create_path_dependency(GetDatasetRequest)
|
||||
unregister_dataset_request = create_path_dependency(UnregisterDatasetRequest)
|
||||
|
||||
|
||||
def create_router(impl: Datasets) -> APIRouter:
|
||||
"""Create a FastAPI router for the Datasets API.
|
||||
|
||||
Args:
|
||||
impl: The Datasets implementation instance
|
||||
|
||||
Returns:
|
||||
APIRouter configured for the Datasets API
|
||||
"""
|
||||
router = APIRouter(
|
||||
prefix=f"/{LLAMA_STACK_API_V1BETA}",
|
||||
tags=["Datasets"],
|
||||
responses=standard_responses,
|
||||
)
|
||||
|
||||
@router.post(
|
||||
"/datasets",
|
||||
response_model=Dataset,
|
||||
summary="Register a new dataset.",
|
||||
description="Register a new dataset.",
|
||||
responses={
|
||||
200: {"description": "The registered dataset object."},
|
||||
},
|
||||
deprecated=True,
|
||||
)
|
||||
async def register_dataset(
|
||||
request: Annotated[RegisterDatasetRequest, Body(...)],
|
||||
) -> Dataset:
|
||||
return await impl.register_dataset(request)
|
||||
|
||||
@router.get(
|
||||
"/datasets/{dataset_id:path}",
|
||||
response_model=Dataset,
|
||||
summary="Get a dataset by its ID.",
|
||||
description="Get a dataset by its ID.",
|
||||
responses={
|
||||
200: {"description": "The dataset object."},
|
||||
},
|
||||
)
|
||||
async def get_dataset(
|
||||
request: Annotated[GetDatasetRequest, Depends(get_dataset_request)],
|
||||
) -> Dataset:
|
||||
return await impl.get_dataset(request)
|
||||
|
||||
@router.get(
|
||||
"/datasets",
|
||||
response_model=ListDatasetsResponse,
|
||||
summary="List all datasets.",
|
||||
description="List all datasets.",
|
||||
responses={
|
||||
200: {"description": "A list of dataset objects."},
|
||||
},
|
||||
)
|
||||
async def list_datasets() -> ListDatasetsResponse:
|
||||
return await impl.list_datasets()
|
||||
|
||||
@router.delete(
|
||||
"/datasets/{dataset_id:path}",
|
||||
summary="Unregister a dataset by its ID.",
|
||||
description="Unregister a dataset by its ID.",
|
||||
responses={
|
||||
200: {"description": "The dataset was successfully unregistered."},
|
||||
},
|
||||
deprecated=True,
|
||||
)
|
||||
async def unregister_dataset(
|
||||
request: Annotated[UnregisterDatasetRequest, Depends(unregister_dataset_request)],
|
||||
) -> None:
|
||||
return await impl.unregister_dataset(request)
|
||||
|
||||
return router
|
||||
152
src/llama_stack_api/datasets/models.py
Normal file
152
src/llama_stack_api/datasets/models.py
Normal file
|
|
@ -0,0 +1,152 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Pydantic models for Datasets API requests and responses.
|
||||
|
||||
This module defines the request and response models for the Datasets API
|
||||
using Pydantic with Field descriptions for OpenAPI schema generation.
|
||||
"""
|
||||
|
||||
from enum import Enum, StrEnum
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack_api.resource import Resource, ResourceType
|
||||
from llama_stack_api.schema_utils import json_schema_type, register_schema
|
||||
|
||||
|
||||
class DatasetPurpose(StrEnum):
|
||||
"""Purpose of the dataset. Each purpose has a required input data schema."""
|
||||
|
||||
post_training_messages = "post-training/messages"
|
||||
"""The dataset contains messages used for post-training."""
|
||||
eval_question_answer = "eval/question-answer"
|
||||
"""The dataset contains a question column and an answer column."""
|
||||
eval_messages_answer = "eval/messages-answer"
|
||||
"""The dataset contains a messages column with list of messages and an answer column."""
|
||||
|
||||
|
||||
class DatasetType(Enum):
|
||||
"""Type of the dataset source."""
|
||||
|
||||
uri = "uri"
|
||||
"""The dataset can be obtained from a URI."""
|
||||
rows = "rows"
|
||||
"""The dataset is stored in rows."""
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class URIDataSource(BaseModel):
|
||||
"""A dataset that can be obtained from a URI."""
|
||||
|
||||
type: Literal["uri"] = Field(default="uri", description="The type of data source.")
|
||||
uri: str = Field(
|
||||
...,
|
||||
description='The dataset can be obtained from a URI. E.g. "https://mywebsite.com/mydata.jsonl", "lsfs://mydata.jsonl", "data:csv;base64,{base64_content}"',
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RowsDataSource(BaseModel):
|
||||
"""A dataset stored in rows."""
|
||||
|
||||
type: Literal["rows"] = Field(default="rows", description="The type of data source.")
|
||||
rows: list[dict[str, Any]] = Field(
|
||||
...,
|
||||
description='The dataset is stored in rows. E.g. [{"messages": [{"role": "user", "content": "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}]}]',
|
||||
)
|
||||
|
||||
|
||||
DataSource = Annotated[
|
||||
URIDataSource | RowsDataSource,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(DataSource, name="DataSource")
|
||||
|
||||
|
||||
class CommonDatasetFields(BaseModel):
|
||||
"""Common fields for a dataset."""
|
||||
|
||||
purpose: DatasetPurpose = Field(..., description="Purpose of the dataset indicating its intended use")
|
||||
source: DataSource = Field(..., description="Data source configuration for the dataset")
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Any additional metadata for this dataset",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Dataset(CommonDatasetFields, Resource):
|
||||
"""Dataset resource for storing and accessing training or evaluation data."""
|
||||
|
||||
type: Literal[ResourceType.dataset] = Field(
|
||||
default=ResourceType.dataset,
|
||||
description="Type of resource, always 'dataset' for datasets",
|
||||
)
|
||||
|
||||
@property
|
||||
def dataset_id(self) -> str:
|
||||
return self.identifier
|
||||
|
||||
@property
|
||||
def provider_dataset_id(self) -> str | None:
|
||||
return self.provider_resource_id
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListDatasetsResponse(BaseModel):
|
||||
"""Response from listing datasets."""
|
||||
|
||||
data: list[Dataset] = Field(..., description="List of datasets")
|
||||
|
||||
|
||||
# Request models for each endpoint
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RegisterDatasetRequest(BaseModel):
|
||||
"""Request model for registering a dataset."""
|
||||
|
||||
purpose: DatasetPurpose = Field(..., description="The purpose of the dataset.")
|
||||
source: DataSource = Field(..., description="The data source of the dataset.")
|
||||
metadata: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="The metadata for the dataset.",
|
||||
)
|
||||
dataset_id: str | None = Field(
|
||||
default=None,
|
||||
description="The ID of the dataset. If not provided, an ID will be generated.",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class GetDatasetRequest(BaseModel):
|
||||
"""Request model for getting a dataset by ID."""
|
||||
|
||||
dataset_id: str = Field(..., description="The ID of the dataset to get.")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class UnregisterDatasetRequest(BaseModel):
|
||||
"""Request model for unregistering a dataset."""
|
||||
|
||||
dataset_id: str = Field(..., description="The ID of the dataset to unregister.")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CommonDatasetFields",
|
||||
"Dataset",
|
||||
"DatasetPurpose",
|
||||
"DatasetType",
|
||||
"DataSource",
|
||||
"RowsDataSource",
|
||||
"URIDataSource",
|
||||
"ListDatasetsResponse",
|
||||
"RegisterDatasetRequest",
|
||||
"GetDatasetRequest",
|
||||
"UnregisterDatasetRequest",
|
||||
]
|
||||
Loading…
Add table
Add a link
Reference in a new issue