forked from phoenix-oss/llama-stack-mirror
datasetio->datasets
This commit is contained in:
parent 78ec3d98f6
commit 89885fd2fa
2 changed files with 20 additions and 36 deletions
@@ -37,8 +37,8 @@ class DatasetIO(Protocol):
     # keeping for aligning with inference/safety, but this is not used
     dataset_store: DatasetStore
 
-    @webmethod(route="/datasetio/rows", method="GET")
-    async def get_rows_paginated(
+    @webmethod(route="/datasets/{dataset_id}/iterrows", method="GET")
+    async def iterrows(
         self,
         dataset_id: str,
         rows_in_page: int,
@@ -54,5 +54,7 @@ class DatasetIO(Protocol):
         """
         ...
 
-    @webmethod(route="/datasetio/rows", method="POST")
-    async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...
+    @webmethod(route="/datasets/{dataset_id}/rows", method="POST")
+    async def append_rows(
+        self, dataset_id: str, rows: List[Dict[str, Any]]
+    ) -> None: ...
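The two hunks above move the DatasetIO row endpoints from /datasetio/rows to dataset-scoped routes and rename get_rows_paginated to iterrows. A minimal Python sketch of how a caller might exercise the renamed methods; `impl` and the dataset id are hypothetical stand-ins for an object implementing the DatasetIO protocol and an already-registered dataset, and the return type of iterrows is not visible in this diff, so the result is handled opaquely.

# Sketch only, not part of this commit. `impl` is a hypothetical object that
# implements the DatasetIO protocol; "my-eval-set" is a hypothetical dataset id.
async def demo(impl) -> None:
    # GET /datasets/{dataset_id}/iterrows (previously GET /datasetio/rows)
    page = await impl.iterrows(dataset_id="my-eval-set", rows_in_page=32)
    print(page)  # the shape of the result is defined elsewhere in the API

    # POST /datasets/{dataset_id}/rows (previously POST /datasetio/rows)
    await impl.append_rows(
        dataset_id="my-eval-set",
        rows=[{"question": "What is 2 + 2?", "answer": "4"}],
    )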
@@ -39,6 +39,7 @@ class DatasetPurpose(Enum):
             "answer": "John Doe"
         }
     """
 
     post_training_messages = "post-training/messages"
     eval_question_answer = "eval/question-answer"
     eval_messages_answer = "eval/messages-answer"
@@ -49,11 +50,10 @@ class DatasetPurpose(Enum):
 class DatasetType(Enum):
     """
     Type of the dataset source.
-    :cvar huggingface: The dataset is stored in Huggingface.
     :cvar uri: The dataset can be obtained from a URI.
     :cvar rows: The dataset is stored in rows.
     """
-    huggingface = "huggingface"
     uri = "uri"
     rows = "rows"
 
@@ -66,30 +66,11 @@ class URIDataSource(BaseModel):
         - "lsfs://mydata.jsonl"
         - "data:csv;base64,{base64_content}"
     """
 
     type: Literal["uri"] = "uri"
     uri: str
 
 
-class HuggingfaceDataSourceFields(BaseModel):
-    """The fields for a Huggingface dataset.
-    :param path: The path to the dataset in Huggingface. E.g.
-        - "llamastack/simpleqa"
-    :param params: The parameters for the dataset.
-    """
-    path: str
-    params: Dict[str, Any]
-
-
-@json_schema_type
-class HuggingfaceDataSource(BaseModel):
-    """A dataset stored in Huggingface.
-    :param type: The type of the data source.
-    :param huggingface: The fields for a Huggingface dataset.
-    """
-    type: Literal["huggingface"] = "huggingface"
-    huggingface: HuggingfaceDataSourceFields
-
-
 @json_schema_type
 class RowsDataSource(BaseModel):
     """A dataset stored in rows.
@@ -98,13 +79,14 @@ class RowsDataSource(BaseModel):
         {"messages": [{"role": "user", "content": "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}]}
     ]
     """
 
     type: Literal["rows"] = "rows"
     rows: List[Dict[str, Any]]
 
 
 DataSource = register_schema(
     Annotated[
-        Union[URIDataSource, HuggingfaceDataSource, RowsDataSource],
+        Union[URIDataSource, RowsDataSource],
         Field(discriminator="type"),
     ],
     name="DataSource",
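The net effect of the three hunks above is that Huggingface is no longer a distinct data source type: the huggingface enum member and the HuggingfaceDataSourceFields/HuggingfaceDataSource models are removed, and a Huggingface dataset is instead expressed as a URIDataSource with a huggingface:// URI. A self-contained Python sketch of the resulting discriminated union; the llama-stack helpers (register_schema, json_schema_type) are left out, and pydantic's TypeAdapter stands in for whatever validation the real code uses, which is an assumption made only to keep the example runnable.

# Sketch only: the models mirror the diff; TypeAdapter-based validation is an
# assumption, not the project's actual wiring.
from typing import Annotated, Any, Dict, List, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


class URIDataSource(BaseModel):
    type: Literal["uri"] = "uri"
    uri: str


class RowsDataSource(BaseModel):
    type: Literal["rows"] = "rows"
    rows: List[Dict[str, Any]]


DataSource = Annotated[
    Union[URIDataSource, RowsDataSource],
    Field(discriminator="type"),
]

adapter = TypeAdapter(DataSource)

# Before: {"type": "huggingface", "huggingface": {"path": ..., "params": {...}}}
# After: the same dataset is addressed by URI, with the split as a query parameter.
hf_as_uri = adapter.validate_python(
    {"type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train"}
)
inline = adapter.validate_python(
    {"type": "rows", "rows": [{"question": "What is 2 + 2?", "answer": "4"}]}
)
print(type(hf_as_uri).__name__, type(inline).__name__)  # URIDataSource RowsDataSource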
@@ -115,6 +97,7 @@ class CommonDatasetFields(BaseModel):
     """
     Common fields for a dataset.
     """
 
     purpose: DatasetPurpose
     source: DataSource
     metadata: Dict[str, Any] = Field(
@@ -190,13 +173,12 @@ class Datasets(Protocol):
                 "uri": "lsfs://mydata.jsonl"
             }
             - {
-                "type": "huggingface",
-                "huggingface": {
-                    "dataset_path": "tatsu-lab/alpaca",
-                    "params": {
-                        "split": "train"
-                    }
-                }
+                "type": "uri",
+                "uri": "data:csv;base64,{base64_content}"
+            }
+            - {
+                "type": "uri",
+                "uri": "huggingface://llamastack/simpleqa?split=train"
             }
             - {
                 "type": "rows",
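The docstring hunk above rewrites the registration examples in the Datasets protocol: the old nested Huggingface payload becomes two URI-based sources. A small sketch of the payload shapes a caller would now pass; the method name register_dataset and its exact signature are not shown in this hunk, so the commented call is an assumption based on the purpose/source/metadata fields of CommonDatasetFields.

# Sketch only: payloads mirror the updated docstring; the registration call at
# the bottom is hypothetical (the method's signature is not part of this diff).
uri_file_source = {"type": "uri", "uri": "lsfs://mydata.jsonl"}
inline_csv_source = {"type": "uri", "uri": "data:csv;base64,{base64_content}"}  # placeholder content
huggingface_source = {"type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train"}
rows_source = {
    "type": "rows",
    "rows": [{"question": "What is 2 + 2?", "answer": "4"}],
}

# Hypothetical call, assuming the Datasets protocol exposes a register method
# built on CommonDatasetFields (purpose, source, metadata):
# await datasets_impl.register_dataset(
#     purpose="eval/question-answer",
#     source=huggingface_source,
#     metadata={"provider": "huggingface"},
# )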