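datasetio -> datasets: move the DatasetIO routes under /datasets/{dataset_id} (get_rows_paginated becomes iterrows) and replace the Huggingface data source with a huggingface:// URI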

This commit is contained in:
Xi Yan 2025-03-13 14:47:06 -07:00
parent 78ec3d98f6
commit 89885fd2fa
2 changed files with 20 additions and 36 deletions

View file

@ -37,8 +37,8 @@ class DatasetIO(Protocol):
# kept for alignment with inference/safety, but this is not used # kept for alignment with inference/safety, but this is not used
dataset_store: DatasetStore dataset_store: DatasetStore
@webmethod(route="/datasetio/rows", method="GET") @webmethod(route="/datasets/{dataset_id}/iterrows", method="GET")
async def get_rows_paginated( async def iterrows(
self, self,
dataset_id: str, dataset_id: str,
rows_in_page: int, rows_in_page: int,
@ -54,5 +54,7 @@ class DatasetIO(Protocol):
""" """
... ...
@webmethod(route="/datasetio/rows", method="POST") @webmethod(route="/datasets/{dataset_id}/rows", method="POST")
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ... async def append_rows(
self, dataset_id: str, rows: List[Dict[str, Any]]
) -> None: ...

View file

@ -39,6 +39,7 @@ class DatasetPurpose(Enum):
"answer": "John Doe" "answer": "John Doe"
} }
""" """
post_training_messages = "post-training/messages" post_training_messages = "post-training/messages"
eval_question_answer = "eval/question-answer" eval_question_answer = "eval/question-answer"
eval_messages_answer = "eval/messages-answer" eval_messages_answer = "eval/messages-answer"
@ -49,11 +50,10 @@ class DatasetPurpose(Enum):
class DatasetType(Enum): class DatasetType(Enum):
""" """
Type of the dataset source. Type of the dataset source.
:cvar huggingface: The dataset is stored in Huggingface. :cvar uri: The dataset can be obtained from a URI.
:cvar uri: The dataset can be obtained from a URI. :cvar rows: The dataset is stored in rows.
:cvar rows: The dataset is stored in rows.
""" """
huggingface = "huggingface"
uri = "uri" uri = "uri"
rows = "rows" rows = "rows"
@ -66,30 +66,11 @@ class URIDataSource(BaseModel):
- "lsfs://mydata.jsonl" - "lsfs://mydata.jsonl"
- "data:csv;base64,{base64_content}" - "data:csv;base64,{base64_content}"
""" """
type: Literal["uri"] = "uri" type: Literal["uri"] = "uri"
uri: str uri: str
class HuggingfaceDataSourceFields(BaseModel):
"""The fields for a Huggingface dataset.
:param path: The path to the dataset in Huggingface. E.g.
- "llamastack/simpleqa"
:param params: The parameters for the dataset.
"""
path: str
params: Dict[str, Any]
@json_schema_type
class HuggingfaceDataSource(BaseModel):
"""A dataset stored in Huggingface.
:param type: The type of the data source.
:param huggingface: The fields for a Huggingface dataset.
"""
type: Literal["huggingface"] = "huggingface"
huggingface: HuggingfaceDataSourceFields
@json_schema_type @json_schema_type
class RowsDataSource(BaseModel): class RowsDataSource(BaseModel):
"""A dataset stored in rows. """A dataset stored in rows.
@ -98,13 +79,14 @@ class RowsDataSource(BaseModel):
{"messages": [{"role": "user", "content": "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}]} {"messages": [{"role": "user", "content": "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}]}
] ]
""" """
type: Literal["rows"] = "rows" type: Literal["rows"] = "rows"
rows: List[Dict[str, Any]] rows: List[Dict[str, Any]]
DataSource = register_schema( DataSource = register_schema(
Annotated[ Annotated[
Union[URIDataSource, HuggingfaceDataSource, RowsDataSource], Union[URIDataSource, RowsDataSource],
Field(discriminator="type"), Field(discriminator="type"),
], ],
name="DataSource", name="DataSource",
@ -115,6 +97,7 @@ class CommonDatasetFields(BaseModel):
""" """
Common fields for a dataset. Common fields for a dataset.
""" """
purpose: DatasetPurpose purpose: DatasetPurpose
source: DataSource source: DataSource
metadata: Dict[str, Any] = Field( metadata: Dict[str, Any] = Field(
@ -190,13 +173,12 @@ class Datasets(Protocol):
"uri": "lsfs://mydata.jsonl" "uri": "lsfs://mydata.jsonl"
} }
- { - {
"type": "huggingface", "type": "uri",
"huggingface": { "uri": "data:csv;base64,{base64_content}"
"dataset_path": "tatsu-lab/alpaca", }
"params": { - {
"split": "train" "type": "uri",
} "uri": "huggingface://llamastack/simpleqa?split=train"
}
} }
- { - {
"type": "rows", "type": "rows",