mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-21 23:02:25 +00:00
# What does this PR do? Convert the Datasets API from webmethod decorators to FastAPI router pattern. Fixes: https://github.com/llamastack/llama-stack/issues/4344 ## Test Plan CI Signed-off-by: Sébastien Han <seb@redhat.com>
152 lines
4.5 KiB
Python
152 lines
4.5 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
"""Pydantic models for Datasets API requests and responses.
|
|
|
|
This module defines the request and response models for the Datasets API
|
|
using Pydantic with Field descriptions for OpenAPI schema generation.
|
|
"""
|
|
|
|
from enum import Enum, StrEnum
|
|
from typing import Annotated, Any, Literal
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from llama_stack_api.resource import Resource, ResourceType
|
|
from llama_stack_api.schema_utils import json_schema_type, register_schema
|
|
|
|
|
|
class DatasetPurpose(StrEnum):
|
|
"""Purpose of the dataset. Each purpose has a required input data schema."""
|
|
|
|
post_training_messages = "post-training/messages"
|
|
"""The dataset contains messages used for post-training."""
|
|
eval_question_answer = "eval/question-answer"
|
|
"""The dataset contains a question column and an answer column."""
|
|
eval_messages_answer = "eval/messages-answer"
|
|
"""The dataset contains a messages column with list of messages and an answer column."""
|
|
|
|
|
|
class DatasetType(Enum):
|
|
"""Type of the dataset source."""
|
|
|
|
uri = "uri"
|
|
"""The dataset can be obtained from a URI."""
|
|
rows = "rows"
|
|
"""The dataset is stored in rows."""
|
|
|
|
|
|
@json_schema_type
|
|
class URIDataSource(BaseModel):
|
|
"""A dataset that can be obtained from a URI."""
|
|
|
|
type: Literal["uri"] = Field(default="uri", description="The type of data source.")
|
|
uri: str = Field(
|
|
...,
|
|
description='The dataset can be obtained from a URI. E.g. "https://mywebsite.com/mydata.jsonl", "lsfs://mydata.jsonl", "data:csv;base64,{base64_content}"',
|
|
)
|
|
|
|
|
|
@json_schema_type
|
|
class RowsDataSource(BaseModel):
|
|
"""A dataset stored in rows."""
|
|
|
|
type: Literal["rows"] = Field(default="rows", description="The type of data source.")
|
|
rows: list[dict[str, Any]] = Field(
|
|
...,
|
|
description='The dataset is stored in rows. E.g. [{"messages": [{"role": "user", "content": "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}]}]',
|
|
)
|
|
|
|
|
|
DataSource = Annotated[
|
|
URIDataSource | RowsDataSource,
|
|
Field(discriminator="type"),
|
|
]
|
|
register_schema(DataSource, name="DataSource")
|
|
|
|
|
|
class CommonDatasetFields(BaseModel):
|
|
"""Common fields for a dataset."""
|
|
|
|
purpose: DatasetPurpose = Field(..., description="Purpose of the dataset indicating its intended use")
|
|
source: DataSource = Field(..., description="Data source configuration for the dataset")
|
|
metadata: dict[str, Any] = Field(
|
|
default_factory=dict,
|
|
description="Any additional metadata for this dataset",
|
|
)
|
|
|
|
|
|
@json_schema_type
|
|
class Dataset(CommonDatasetFields, Resource):
|
|
"""Dataset resource for storing and accessing training or evaluation data."""
|
|
|
|
type: Literal[ResourceType.dataset] = Field(
|
|
default=ResourceType.dataset,
|
|
description="Type of resource, always 'dataset' for datasets",
|
|
)
|
|
|
|
@property
|
|
def dataset_id(self) -> str:
|
|
return self.identifier
|
|
|
|
@property
|
|
def provider_dataset_id(self) -> str | None:
|
|
return self.provider_resource_id
|
|
|
|
|
|
@json_schema_type
|
|
class ListDatasetsResponse(BaseModel):
|
|
"""Response from listing datasets."""
|
|
|
|
data: list[Dataset] = Field(..., description="List of datasets")
|
|
|
|
|
|
# Request models for each endpoint
|
|
|
|
|
|
@json_schema_type
|
|
class RegisterDatasetRequest(BaseModel):
|
|
"""Request model for registering a dataset."""
|
|
|
|
purpose: DatasetPurpose = Field(..., description="The purpose of the dataset.")
|
|
source: DataSource = Field(..., description="The data source of the dataset.")
|
|
metadata: dict[str, Any] | None = Field(
|
|
default=None,
|
|
description="The metadata for the dataset.",
|
|
)
|
|
dataset_id: str | None = Field(
|
|
default=None,
|
|
description="The ID of the dataset. If not provided, an ID will be generated.",
|
|
)
|
|
|
|
|
|
@json_schema_type
|
|
class GetDatasetRequest(BaseModel):
|
|
"""Request model for getting a dataset by ID."""
|
|
|
|
dataset_id: str = Field(..., description="The ID of the dataset to get.")
|
|
|
|
|
|
@json_schema_type
|
|
class UnregisterDatasetRequest(BaseModel):
|
|
"""Request model for unregistering a dataset."""
|
|
|
|
dataset_id: str = Field(..., description="The ID of the dataset to unregister.")
|
|
|
|
|
|
__all__ = [
|
|
"CommonDatasetFields",
|
|
"Dataset",
|
|
"DatasetPurpose",
|
|
"DatasetType",
|
|
"DataSource",
|
|
"RowsDataSource",
|
|
"URIDataSource",
|
|
"ListDatasetsResponse",
|
|
"RegisterDatasetRequest",
|
|
"GetDatasetRequest",
|
|
"UnregisterDatasetRequest",
|
|
]
|