This commit is contained in:
Xi Yan 2025-03-12 18:46:40 -07:00
parent 18de4cd08a
commit a3173e8284
3 changed files with 95 additions and 41 deletions

View file

@ -16,7 +16,7 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
class DatasetPurpose(Enum):
"""
Purpose of the dataset. Each type has a different column format.
:cvar tuning/messages: The dataset contains messages used for post-training. Examples:
:cvar post-training/messages: The dataset contains messages used for post-training. Examples:
{
"messages": [
{"role": "user", "content": "Hello, world!"},
@ -25,12 +25,19 @@ class DatasetPurpose(Enum):
}
"""
tuning_messages = "tuning/messages"
post_training_messages = "post-training/messages"
eval_question_answer = "eval/question-answer"
# TODO: add more purposes (and their column schemas) here
class DatasetType(Enum):
"""
Type of the dataset source.
:cvar huggingface: The dataset is stored on Hugging Face.
:cvar uri: The dataset can be obtained from a URI.
:cvar rows: The dataset is given directly as rows.
"""
# Each value matches the `type` literal declared on the corresponding
# data source model (HuggingfaceDataSource, URIDataSource, RowsDataSource).
huggingface = "huggingface"
uri = "uri"
rows = "rows"
@ -38,19 +45,36 @@ class DatasetType(Enum):
@json_schema_type
class URIDataSource(BaseModel):
"""A dataset that can be obtained from a URI.
:param type: Fixed tag, always "uri", identifying this kind of data source.
:param uri: The URI the dataset can be obtained from. E.g.
- "https://mywebsite.com/mydata.jsonl"
- "lsfs://mydata.jsonl"
- "data:csv;base64,{base64_content}"
"""
type: Literal["uri"] = "uri"
uri: str
@json_schema_type
class HuggingfaceDataSource(BaseModel):
"""A dataset stored on Hugging Face.
:param type: Fixed tag, always "huggingface", identifying this kind of data source.
:param path: The path to the dataset on Hugging Face. E.g.
- "llamastack/simpleqa"
:param params: The parameters for the dataset.
"""
type: Literal["huggingface"] = "huggingface"
# NOTE(review): both `dataset_path` and `path` are declared here, but only
# `path` is documented above. This looks like the before/after pair of a
# rename -- confirm which field is meant to remain.
dataset_path: str
path: str
params: Dict[str, Any]
@json_schema_type
class RowsDataSource(BaseModel):
"""A dataset whose rows are given directly.
:param type: Fixed tag, always "rows", identifying this kind of data source.
:param rows: The rows of the dataset, one dictionary per row. E.g.
- [
{"messages": [{"role": "user", "content": "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}]}
]
"""
type: Literal["rows"] = "rows"
rows: List[Dict[str, Any]]
@ -65,7 +89,10 @@ DataSource = register_schema(
class CommonDatasetFields(BaseModel):
schema: Schema
"""
Common fields for a dataset.
"""
purpose: DatasetPurpose
data_source: DataSource
metadata: Dict[str, Any] = Field(
default_factory=dict,
@ -108,9 +135,10 @@ class Datasets(Protocol):
"""
Register a new dataset.
:param schema: The schema format of the dataset. One of
- messages: The dataset contains a messages column with list of messages for post-training.
:param data_source: The data source of the dataset. Examples:
:param purpose: The purpose of the dataset. One of
- "post-training/messages": The dataset contains a messages column with list of messages for post-training.
- "eval/question-answer": The dataset contains a question and answer column.
:param source: The data source of the dataset. Examples:
- {
"type": "uri",
"uri": "https://mywebsite.com/mydata.jsonl"