datasetio->datasets

This commit is contained in:
Xi Yan 2025-03-13 14:47:06 -07:00
parent 78ec3d98f6
commit 89885fd2fa
2 changed files with 20 additions and 36 deletions

View file

@ -37,8 +37,8 @@ class DatasetIO(Protocol):
# Kept to align with inference/safety, but not currently used
dataset_store: DatasetStore
@webmethod(route="/datasetio/rows", method="GET")
async def get_rows_paginated(
@webmethod(route="/datasets/{dataset_id}/iterrows", method="GET")
async def iterrows(
self,
dataset_id: str,
rows_in_page: int,
@ -54,5 +54,7 @@ class DatasetIO(Protocol):
"""
...
@webmethod(route="/datasetio/rows", method="POST")
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...
@webmethod(route="/datasets/{dataset_id}/rows", method="POST")
async def append_rows(
self, dataset_id: str, rows: List[Dict[str, Any]]
) -> None: ...

View file

@ -39,6 +39,7 @@ class DatasetPurpose(Enum):
"answer": "John Doe"
}
"""
post_training_messages = "post-training/messages"
eval_question_answer = "eval/question-answer"
eval_messages_answer = "eval/messages-answer"
@ -49,11 +50,10 @@ class DatasetPurpose(Enum):
class DatasetType(Enum):
"""
Type of the dataset source.
:cvar huggingface: The dataset is stored in Huggingface.
:cvar uri: The dataset can be obtained from a URI.
:cvar rows: The dataset is stored in rows.
:cvar uri: The dataset can be obtained from a URI.
:cvar rows: The dataset is stored in rows.
"""
huggingface = "huggingface"
uri = "uri"
rows = "rows"
@ -66,30 +66,11 @@ class URIDataSource(BaseModel):
- "lsfs://mydata.jsonl"
- "data:csv;base64,{base64_content}"
"""
type: Literal["uri"] = "uri"
uri: str
class HuggingfaceDataSourceFields(BaseModel):
"""The fields for a Huggingface dataset.
:param path: The path to the dataset in Huggingface. E.g.
- "llamastack/simpleqa"
:param params: The parameters for the dataset.
"""
path: str
params: Dict[str, Any]
@json_schema_type
class HuggingfaceDataSource(BaseModel):
"""A dataset stored in Huggingface.
:param type: The type of the data source.
:param huggingface: The fields for a Huggingface dataset.
"""
type: Literal["huggingface"] = "huggingface"
huggingface: HuggingfaceDataSourceFields
@json_schema_type
class RowsDataSource(BaseModel):
"""A dataset stored in rows.
@ -98,13 +79,14 @@ class RowsDataSource(BaseModel):
{"messages": [{"role": "user", "content": "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}]}
]
"""
type: Literal["rows"] = "rows"
rows: List[Dict[str, Any]]
DataSource = register_schema(
Annotated[
Union[URIDataSource, HuggingfaceDataSource, RowsDataSource],
Union[URIDataSource, RowsDataSource],
Field(discriminator="type"),
],
name="DataSource",
@ -115,6 +97,7 @@ class CommonDatasetFields(BaseModel):
"""
Common fields for a dataset.
"""
purpose: DatasetPurpose
source: DataSource
metadata: Dict[str, Any] = Field(
@ -190,13 +173,12 @@ class Datasets(Protocol):
"uri": "lsfs://mydata.jsonl"
}
- {
"type": "huggingface",
"huggingface": {
"dataset_path": "tatsu-lab/alpaca",
"params": {
"split": "train"
}
}
"type": "uri",
"uri": "data:csv;base64,{base64_content}"
}
- {
"type": "uri",
"uri": "huggingface://llamastack/simpleqa?split=train"
}
- {
"type": "rows",