diff --git a/llama_stack/apis/datasetio/datasetio.py b/llama_stack/apis/datasetio/datasetio.py index 6a04a6329..2c6c8e981 100644 --- a/llama_stack/apis/datasetio/datasetio.py +++ b/llama_stack/apis/datasetio/datasetio.py @@ -37,8 +37,8 @@ class DatasetIO(Protocol): # keeping for aligning with inference/safety, but this is not used dataset_store: DatasetStore - @webmethod(route="/datasetio/rows", method="GET") - async def get_rows_paginated( + @webmethod(route="/datasets/{dataset_id}/iterrows", method="GET") + async def iterrows( self, dataset_id: str, rows_in_page: int, @@ -54,5 +54,7 @@ class DatasetIO(Protocol): """ ... - @webmethod(route="/datasetio/rows", method="POST") - async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ... + @webmethod(route="/datasets/{dataset_id}/rows", method="POST") + async def append_rows( + self, dataset_id: str, rows: List[Dict[str, Any]] + ) -> None: ... diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index e1285ef9a..c25b861c3 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -39,6 +39,7 @@ class DatasetPurpose(Enum): "answer": "John Doe" } """ + post_training_messages = "post-training/messages" eval_question_answer = "eval/question-answer" eval_messages_answer = "eval/messages-answer" @@ -49,11 +50,10 @@ class DatasetPurpose(Enum): class DatasetType(Enum): """ Type of the dataset source. - :cvar huggingface: The dataset is stored in Huggingface. - :cvar uri: The dataset can be obtained from a URI. - :cvar rows: The dataset is stored in rows. + :cvar uri: The dataset can be obtained from a URI. + :cvar rows: The dataset is stored in rows. """ - huggingface = "huggingface" + uri = "uri" rows = "rows" @@ -66,30 +66,11 @@ class URIDataSource(BaseModel): - "lsfs://mydata.jsonl" - "data:csv;base64,{base64_content}" """ + type: Literal["uri"] = "uri" uri: str -class HuggingfaceDataSourceFields(BaseModel): - """The fields for a Huggingface dataset. - :param path: The path to the dataset in Huggingface. E.g. - - "llamastack/simpleqa" - :param params: The parameters for the dataset. - """ - path: str - params: Dict[str, Any] - - -@json_schema_type -class HuggingfaceDataSource(BaseModel): - """A dataset stored in Huggingface. - :param type: The type of the data source. - :param huggingface: The fields for a Huggingface dataset. - """ - type: Literal["huggingface"] = "huggingface" - huggingface: HuggingfaceDataSourceFields - - @json_schema_type class RowsDataSource(BaseModel): """A dataset stored in rows. @@ -98,13 +79,14 @@ class RowsDataSource(BaseModel): {"messages": [{"role": "user", "content": "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}]} ] """ + type: Literal["rows"] = "rows" rows: List[Dict[str, Any]] DataSource = register_schema( Annotated[ - Union[URIDataSource, HuggingfaceDataSource, RowsDataSource], + Union[URIDataSource, RowsDataSource], Field(discriminator="type"), ], name="DataSource", @@ -115,6 +97,7 @@ class CommonDatasetFields(BaseModel): """ Common fields for a dataset. """ + purpose: DatasetPurpose source: DataSource metadata: Dict[str, Any] = Field( @@ -190,13 +173,12 @@ class Datasets(Protocol): "uri": "lsfs://mydata.jsonl" } - { - "type": "huggingface", - "huggingface": { - "dataset_path": "tatsu-lab/alpaca", - "params": { - "split": "train" - } - } + "type": "uri", + "uri": "data:csv;base64,{base64_content}" + } + - { + "type": "uri", + "uri": "huggingface://llamastack/simpleqa?split=train" } - { "type": "rows",