This commit is contained in:
Xi Yan 2025-03-11 18:29:55 -07:00
parent 02aa9a1e85
commit 0e47c65051
3 changed files with 294 additions and 82 deletions

View file

@ -5,12 +5,12 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol
from typing import Any, Dict, List, Literal, Optional, Protocol, Annotated, Union
from pydantic import BaseModel, Field
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, webmethod
from llama_stack.schema_utils import json_schema_type, webmethod, register_schema
class Schema(Enum):
@ -29,9 +29,42 @@ class Schema(Enum):
# TODO: add more schemas here
class DatasetType(Enum):
huggingface = "huggingface"
uri = "uri"
rows = "rows"
@json_schema_type
class URIDataReference(BaseModel):
type: Literal["uri"] = "uri"
uri: str
@json_schema_type
class HuggingfaceDataReference(BaseModel):
type: Literal["huggingface"] = "huggingface"
dataset_path: str
params: Dict[str, Any]
@json_schema_type
class RowsDataReference(BaseModel):
type: Literal["rows"] = "rows"
rows: List[Dict[str, Any]]
DataReference = register_schema(
Annotated[
Union[URIDataReference, HuggingfaceDataReference, RowsDataReference],
Field(discriminator="type"),
],
name="DataReference",
)
class CommonDatasetFields(BaseModel):
schema: Schema
uri: str
data_reference: DataReference
metadata: Dict[str, Any] = Field(
default_factory=dict,
description="Any additional metadata for this dataset",
@ -66,8 +99,7 @@ class Datasets(Protocol):
async def register_dataset(
self,
schema: Schema,
uri: str,
uri_params: Optional[Dict[str, Any]] = None,
data_reference: DataReference,
metadata: Optional[Dict[str, Any]] = None,
dataset_id: Optional[str] = None,
) -> Dataset:
@ -76,13 +108,26 @@ class Datasets(Protocol):
:param schema: The schema format of the dataset. One of
- jsonl_messages: The dataset is a JSONL file with messages in column format
:param uri: The URI of the dataset. Examples:
- file://mydata.jsonl
- s3://mybucket/myfile.jsonl
- https://mywebsite.com/myfile.jsonl
- huggingface://tatsu-lab/alpaca
:param uri_params: The parameters for the URI.
- E.g. If URL is a huggingface dataset, parameters could be uri_params={"split": "train"}
:param data_reference: The data reference of the dataset. Examples:
- {
"type": "uri",
"uri": "https://mywebsite.com/mydata.jsonl"
}
- {
"type": "uri",
"uri": "lsfs://mydata.jsonl"
}
- {
"type": "huggingface",
"dataset_path": "tatsu-lab/alpaca",
"params": {
"split": "train"
}
}
- {
"type": "rows",
"rows": [{"message": "Hello, world!"}]
}
:param metadata: The metadata for the dataset.
- E.g. {"description": "My dataset"}
:param dataset_id: The ID of the dataset. If not provided, a random ID will be generated.