# Patch summary (commentary only; text before the first "diff --git" header
# is not part of any hunk):
#   * Adds a `Schema` enum, decorated with `json_schema_type`, whose only
#     member is `jsonl_messages`.
#   * `CommonDatasetFields`: replaces `dataset_schema: Dict[str, ParamType]`
#     and `url: URL` with `schema: Schema` and `uri: str`.
#   * `Datasets.register_dataset`: drops `dataset_id`, `provider_dataset_id`
#     and `provider_id`; now takes `schema`, `uri`, `metadata`, and its return
#     annotation changes from `None` to `Dataset`.
#   * Drops the `URL` and `ParamType` imports; adds `from enum import Enum`.
#
# NOTE(review): `schema` is declared as a field on a pydantic `BaseModel`
#   subclass; pydantic's `BaseModel` appears to define an attribute of the same
#   name -- confirm this does not raise a field-shadowing warning/error with
#   the pydantic version in use.
# NOTE(review): `register_dataset` no longer accepts `dataset_id`, while the
#   `get_dataset` route still takes `{dataset_id:path}`; where the identifier
#   is assigned is not visible in this patch -- TODO confirm.
# NOTE(review): the example in the `Schema` docstring is pretty-printed and has
#   trailing commas, so it is illustrative rather than a literal JSONL record.
diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py
index fe9d30e2a..6f10f93e9 100644
--- a/llama_stack/apis/datasets/datasets.py
+++ b/llama_stack/apis/datasets/datasets.py
@@ -4,19 +4,35 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from enum import Enum
 from typing import Any, Dict, List, Literal, Optional, Protocol
 
 from pydantic import BaseModel, Field
 
-from llama_stack.apis.common.content_types import URL
-from llama_stack.apis.common.type_system import ParamType
 from llama_stack.apis.resource import Resource, ResourceType
 from llama_stack.schema_utils import json_schema_type, webmethod
 
 
+@json_schema_type
+class Schema(Enum):
+    """
+    Schema of the dataset. Each type has a different column format.
+
+    :cvar jsonl_messages: The dataset is a JSONL file with messages. Examples:
+        {
+            "messages": [
+                {"role": "user", "content": "Hello, world!"},
+                {"role": "assistant", "content": "Hello, world!"},
+            ]
+        }
+    """
+
+    jsonl_messages = "jsonl_messages"
+
+
 class CommonDatasetFields(BaseModel):
-    dataset_schema: Dict[str, ParamType]
-    url: URL
+    schema: Schema
+    uri: str
     metadata: Dict[str, Any] = Field(
         default_factory=dict,
         description="Any additional metadata for this dataset",
@@ -50,13 +66,10 @@ class Datasets(Protocol):
     @webmethod(route="/datasets", method="POST")
     async def register_dataset(
         self,
-        dataset_id: str,
-        dataset_schema: Dict[str, ParamType],
-        url: URL,
-        provider_dataset_id: Optional[str] = None,
-        provider_id: Optional[str] = None,
+        schema: Schema,
+        uri: str,
         metadata: Optional[Dict[str, Any]] = None,
-    ) -> None: ...
+    ) -> Dataset: ...
 
     @webmethod(route="/datasets/{dataset_id:path}", method="GET")
     async def get_dataset(