mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-22 02:55:50 +00:00
feat: convert Datasets API to use FastAPI router (#4359)
# What does this PR do? Convert the Datasets API from webmethod decorators to FastAPI router pattern. Fixes: https://github.com/llamastack/llama-stack/issues/4344 ## Test Plan CI Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
56f946f3f5
commit
700663028f
13 changed files with 716 additions and 335 deletions
152
src/llama_stack_api/datasets/models.py
Normal file
152
src/llama_stack_api/datasets/models.py
Normal file
|
|
@ -0,0 +1,152 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Pydantic models for Datasets API requests and responses.
|
||||
|
||||
This module defines the request and response models for the Datasets API
|
||||
using Pydantic with Field descriptions for OpenAPI schema generation.
|
||||
"""
|
||||
|
||||
from enum import Enum, StrEnum
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack_api.resource import Resource, ResourceType
|
||||
from llama_stack_api.schema_utils import json_schema_type, register_schema
|
||||
|
||||
|
||||
class DatasetPurpose(StrEnum):
|
||||
"""Purpose of the dataset. Each purpose has a required input data schema."""
|
||||
|
||||
post_training_messages = "post-training/messages"
|
||||
"""The dataset contains messages used for post-training."""
|
||||
eval_question_answer = "eval/question-answer"
|
||||
"""The dataset contains a question column and an answer column."""
|
||||
eval_messages_answer = "eval/messages-answer"
|
||||
"""The dataset contains a messages column with list of messages and an answer column."""
|
||||
|
||||
|
||||
class DatasetType(Enum):
|
||||
"""Type of the dataset source."""
|
||||
|
||||
uri = "uri"
|
||||
"""The dataset can be obtained from a URI."""
|
||||
rows = "rows"
|
||||
"""The dataset is stored in rows."""
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class URIDataSource(BaseModel):
|
||||
"""A dataset that can be obtained from a URI."""
|
||||
|
||||
type: Literal["uri"] = Field(default="uri", description="The type of data source.")
|
||||
uri: str = Field(
|
||||
...,
|
||||
description='The dataset can be obtained from a URI. E.g. "https://mywebsite.com/mydata.jsonl", "lsfs://mydata.jsonl", "data:csv;base64,{base64_content}"',
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RowsDataSource(BaseModel):
|
||||
"""A dataset stored in rows."""
|
||||
|
||||
type: Literal["rows"] = Field(default="rows", description="The type of data source.")
|
||||
rows: list[dict[str, Any]] = Field(
|
||||
...,
|
||||
description='The dataset is stored in rows. E.g. [{"messages": [{"role": "user", "content": "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}]}]',
|
||||
)
|
||||
|
||||
|
||||
DataSource = Annotated[
|
||||
URIDataSource | RowsDataSource,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(DataSource, name="DataSource")
|
||||
|
||||
|
||||
class CommonDatasetFields(BaseModel):
|
||||
"""Common fields for a dataset."""
|
||||
|
||||
purpose: DatasetPurpose = Field(..., description="Purpose of the dataset indicating its intended use")
|
||||
source: DataSource = Field(..., description="Data source configuration for the dataset")
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Any additional metadata for this dataset",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Dataset(CommonDatasetFields, Resource):
|
||||
"""Dataset resource for storing and accessing training or evaluation data."""
|
||||
|
||||
type: Literal[ResourceType.dataset] = Field(
|
||||
default=ResourceType.dataset,
|
||||
description="Type of resource, always 'dataset' for datasets",
|
||||
)
|
||||
|
||||
@property
|
||||
def dataset_id(self) -> str:
|
||||
return self.identifier
|
||||
|
||||
@property
|
||||
def provider_dataset_id(self) -> str | None:
|
||||
return self.provider_resource_id
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListDatasetsResponse(BaseModel):
|
||||
"""Response from listing datasets."""
|
||||
|
||||
data: list[Dataset] = Field(..., description="List of datasets")
|
||||
|
||||
|
||||
# Request models for each endpoint
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RegisterDatasetRequest(BaseModel):
|
||||
"""Request model for registering a dataset."""
|
||||
|
||||
purpose: DatasetPurpose = Field(..., description="The purpose of the dataset.")
|
||||
source: DataSource = Field(..., description="The data source of the dataset.")
|
||||
metadata: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="The metadata for the dataset.",
|
||||
)
|
||||
dataset_id: str | None = Field(
|
||||
default=None,
|
||||
description="The ID of the dataset. If not provided, an ID will be generated.",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class GetDatasetRequest(BaseModel):
|
||||
"""Request model for getting a dataset by ID."""
|
||||
|
||||
dataset_id: str = Field(..., description="The ID of the dataset to get.")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class UnregisterDatasetRequest(BaseModel):
|
||||
"""Request model for unregistering a dataset."""
|
||||
|
||||
dataset_id: str = Field(..., description="The ID of the dataset to unregister.")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CommonDatasetFields",
|
||||
"Dataset",
|
||||
"DatasetPurpose",
|
||||
"DatasetType",
|
||||
"DataSource",
|
||||
"RowsDataSource",
|
||||
"URIDataSource",
|
||||
"ListDatasetsResponse",
|
||||
"RegisterDatasetRequest",
|
||||
"GetDatasetRequest",
|
||||
"UnregisterDatasetRequest",
|
||||
]
|
||||
Loading…
Add table
Add a link
Reference in a new issue