Merge branch 'pr1573' into api_2

This commit is contained in:
Xi Yan 2025-03-13 14:49:04 -07:00
commit 0c37951395
4 changed files with 246 additions and 354 deletions

View file

@ -37,8 +37,8 @@ class DatasetIO(Protocol):
# Kept to align with inference/safety, but it is not used.
dataset_store: DatasetStore
@webmethod(route="/datasetio/rows", method="GET")
async def get_rows_paginated(
@webmethod(route="/datasets/{dataset_id}/iterrows", method="GET")
async def iterrows(
self,
dataset_id: str,
rows_in_page: int,
@ -54,5 +54,7 @@ class DatasetIO(Protocol):
"""
...
@webmethod(route="/datasetio/rows", method="POST")
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...
@webmethod(route="/datasets/{dataset_id}/rows", method="POST")
async def append_rows(
self, dataset_id: str, rows: List[Dict[str, Any]]
) -> None: ...

View file

@ -60,6 +60,7 @@ class DatasetPurpose(Enum):
"answer": "Paris"
}
"""
post_training_messages = "post-training/messages"
eval_question_answer = "eval/question-answer"
@ -75,11 +76,10 @@ class DatasetPurpose(Enum):
class DatasetType(Enum):
"""
Type of the dataset source.
:cvar huggingface: The dataset is stored in Huggingface.
:cvar uri: The dataset can be obtained from a URI.
:cvar rows: The dataset is stored in rows.
:cvar uri: The dataset can be obtained from a URI.
:cvar rows: The dataset is stored in rows.
"""
huggingface = "huggingface"
uri = "uri"
rows = "rows"
@ -92,30 +92,11 @@ class URIDataSource(BaseModel):
- "lsfs://mydata.jsonl"
- "data:csv;base64,{base64_content}"
"""
type: Literal["uri"] = "uri"
uri: str
class HuggingfaceDataSourceFields(BaseModel):
"""The fields for a Huggingface dataset.
:param path: The path to the dataset in Huggingface. E.g.
- "llamastack/simpleqa"
:param params: The parameters for the dataset.
"""
path: str
params: Dict[str, Any]
@json_schema_type
class HuggingfaceDataSource(BaseModel):
"""A dataset stored in Huggingface.
:param type: The type of the data source.
:param huggingface: The fields for a Huggingface dataset.
"""
type: Literal["huggingface"] = "huggingface"
huggingface: HuggingfaceDataSourceFields
@json_schema_type
class RowsDataSource(BaseModel):
"""A dataset stored in rows.
@ -124,13 +105,14 @@ class RowsDataSource(BaseModel):
{"messages": [{"role": "user", "content": "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}]}
]
"""
type: Literal["rows"] = "rows"
rows: List[Dict[str, Any]]
DataSource = register_schema(
Annotated[
Union[URIDataSource, HuggingfaceDataSource, RowsDataSource],
Union[URIDataSource, RowsDataSource],
Field(discriminator="type"),
],
name="DataSource",
@ -141,6 +123,7 @@ class CommonDatasetFields(BaseModel):
"""
Common fields for a dataset.
"""
purpose: DatasetPurpose
source: DataSource
metadata: Dict[str, Any] = Field(
@ -237,13 +220,12 @@ class Datasets(Protocol):
"uri": "lsfs://mydata.jsonl"
}
- {
"type": "huggingface",
"huggingface": {
"dataset_path": "tatsu-lab/alpaca",
"params": {
"split": "train"
}
}
"type": "uri",
"uri": "data:csv;base64,{base64_content}"
}
- {
"type": "uri",
"uri": "huggingface://llamastack/simpleqa?split=train"
}
- {
"type": "rows",
@ -258,7 +240,7 @@ class Datasets(Protocol):
}
:param metadata: The metadata for the dataset.
- E.g. {"description": "My dataset"}
:param dataset_id: The ID of the dataset. If not provided, a random ID will be generated.
:param dataset_id: The ID of the dataset. If not provided, an ID will be generated.
"""
...