diff --git a/llama_stack/apis/datasetio/datasetio.py b/llama_stack/apis/datasetio/datasetio.py index b321b260e..49a07c9b1 100644 --- a/llama_stack/apis/datasetio/datasetio.py +++ b/llama_stack/apis/datasetio/datasetio.py @@ -21,7 +21,7 @@ class PaginatedRowsResult(BaseModel): class DatasetStore(Protocol): - def get_dataset(self, identifier: str) -> DatasetDefWithProvider: ... + def get_dataset(self, dataset_id: str) -> Dataset: ... @runtime_checkable diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index 1695c888b..896fd818e 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -10,19 +10,16 @@ from llama_models.llama3.api.datatypes import URL from llama_models.schema_utils import json_schema_type, webmethod -from pydantic import BaseModel, Field +from pydantic import Field from llama_stack.apis.common.type_system import ParamType +from llama_stack.apis.resource import Resource @json_schema_type -class DatasetDef(BaseModel): - identifier: str = Field( - description="A unique name for the dataset", - ) - dataset_schema: Dict[str, ParamType] = Field( - description="The schema definition for this dataset", - ) +class Dataset(Resource): + type: Literal["dataset"] = "dataset" + schema: Dict[str, ParamType] url: URL metadata: Dict[str, Any] = Field( default_factory=dict, @@ -30,26 +27,23 @@ class DatasetDef(BaseModel): ) -@json_schema_type -class DatasetDefWithProvider(DatasetDef): - type: Literal["dataset"] = "dataset" - provider_id: str = Field( - description="ID of the provider which serves this dataset", - ) - - class Datasets(Protocol): @webmethod(route="/datasets/register", method="POST") async def register_dataset( self, - dataset_def: DatasetDefWithProvider, + dataset_id: str, + schema: Dict[str, ParamType], + url: URL, + provider_dataset_id: Optional[str] = None, + provider_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, ) -> None: ... @webmethod(route="/datasets/get", method="GET") async def get_dataset( self, - dataset_identifier: str, - ) -> Optional[DatasetDefWithProvider]: ... + dataset_id: str, + ) -> Optional[Dataset]: ... @webmethod(route="/datasets/list", method="GET") - async def list_datasets(self) -> List[DatasetDefWithProvider]: ... + async def list_datasets(self) -> List[Dataset]: ... diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index ebc511b02..9098f4331 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -34,7 +34,7 @@ RoutableObject = Union[ Model, Shield, MemoryBank, - DatasetDef, + Dataset, ScoringFnDef, ] @@ -44,7 +44,7 @@ RoutableObjectWithProvider = Annotated[ Model, Shield, MemoryBank, - DatasetDefWithProvider, + Dataset, ScoringFnDefWithProvider, ], Field(discriminator="type"), diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index aa61580b2..ad246789e 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -17,6 +17,9 @@ from llama_stack.apis.datasets import * # noqa: F403 from llama_stack.apis.eval_tasks import * # noqa: F403 +from llama_models.llama3.api.datatypes import URL + +from llama_stack.apis.common.type_system import ParamType from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.datatypes import * # noqa: F403 @@ -94,8 +97,6 @@ class CommonRoutingTableImpl(RoutingTable): elif api == Api.datasetio: p.dataset_store = self - datasets = await p.list_datasets() - await add_objects(datasets, pid, DatasetDefWithProvider) elif api == Api.scoring: p.scoring_function_store = self @@ -302,16 +303,42 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): - async def list_datasets(self) -> List[DatasetDefWithProvider]: + async def list_datasets(self) -> List[Dataset]: return await self.get_all_with_type("dataset") - async def get_dataset( - self, dataset_identifier: str - ) -> Optional[DatasetDefWithProvider]: - return await self.get_object_by_identifier(dataset_identifier) + async def get_dataset(self, dataset_id: str) -> Optional[Dataset]: + return await self.get_object_by_identifier(dataset_id) - async def register_dataset(self, dataset_def: DatasetDefWithProvider) -> None: - await self.register_object(dataset_def) + async def register_dataset( + self, + dataset_id: str, + schema: Dict[str, ParamType], + url: URL, + provider_dataset_id: Optional[str] = None, + provider_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + if provider_dataset_id is None: + provider_dataset_id = dataset_id + if provider_id is None: + # If provider_id not specified, use the only provider if it supports this dataset + if len(self.impls_by_provider_id) == 1: + provider_id = list(self.impls_by_provider_id.keys())[0] + else: + raise ValueError( + "No provider specified and multiple providers available. Please specify a provider_id." + ) + if metadata is None: + metadata = {} + dataset = Dataset( + identifier=dataset_id, + provider_resource_id=provider_dataset_id, + provider_id=provider_id, + schema=schema, + url=url, + metadata=metadata, + ) + await self.register_object(dataset) class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): diff --git a/llama_stack/providers/adapters/datasetio/huggingface/huggingface.py b/llama_stack/providers/adapters/datasetio/huggingface/huggingface.py index 598ca5cfd..cd143a3ef 100644 --- a/llama_stack/providers/adapters/datasetio/huggingface/huggingface.py +++ b/llama_stack/providers/adapters/datasetio/huggingface/huggingface.py @@ -3,7 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import List, Optional +from typing import Optional from llama_stack.apis.datasetio import * # noqa: F403 @@ -15,7 +15,7 @@ from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_u from .config import HuggingfaceDatasetIOConfig -def load_hf_dataset(dataset_def: DatasetDef): +def load_hf_dataset(dataset_def: Dataset): if dataset_def.metadata.get("path", None): return hf_datasets.load_dataset(**dataset_def.metadata) @@ -41,13 +41,10 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): async def register_dataset( self, - dataset_def: DatasetDef, + dataset_def: Dataset, ) -> None: self.dataset_infos[dataset_def.identifier] = dataset_def - async def list_datasets(self) -> List[DatasetDef]: - return list(self.dataset_infos.values()) - async def get_rows_paginated( self, dataset_id: str, diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index ed2033494..aeb0be742 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -11,7 +11,7 @@ from urllib.parse import urlparse from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field -from llama_stack.apis.datasets import DatasetDef +from llama_stack.apis.datasets import Dataset from llama_stack.apis.eval_tasks import EvalTaskDef from llama_stack.apis.memory_banks.memory_banks import MemoryBank from llama_stack.apis.models import Model @@ -57,9 +57,7 @@ class MemoryBanksProtocolPrivate(Protocol): class DatasetsProtocolPrivate(Protocol): - async def list_datasets(self) -> List[DatasetDef]: ... - - async def register_dataset(self, dataset_def: DatasetDef) -> None: ... + async def register_dataset(self, dataset: Dataset) -> None: ... class ScoringFunctionsProtocolPrivate(Protocol): diff --git a/llama_stack/providers/inline/datasetio/localfs/datasetio.py b/llama_stack/providers/inline/datasetio/localfs/datasetio.py index d8c100684..f54905a6b 100644 --- a/llama_stack/providers/inline/datasetio/localfs/datasetio.py +++ b/llama_stack/providers/inline/datasetio/localfs/datasetio.py @@ -3,7 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import List, Optional +from typing import Optional import pandas from llama_models.llama3.api.datatypes import * # noqa: F403 @@ -37,12 +37,12 @@ class BaseDataset(ABC): @dataclass class DatasetInfo: - dataset_def: DatasetDef + dataset_def: Dataset dataset_impl: BaseDataset class PandasDataframeDataset(BaseDataset): - def __init__(self, dataset_def: DatasetDef, *args, **kwargs) -> None: + def __init__(self, dataset_def: Dataset, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.dataset_def = dataset_def self.df = None @@ -60,9 +60,9 @@ class PandasDataframeDataset(BaseDataset): def _validate_dataset_schema(self, df) -> pandas.DataFrame: # note that we will drop any columns in dataset that are not in the schema - df = df[self.dataset_def.dataset_schema.keys()] + df = df[self.dataset_def.schema.keys()] # check all columns in dataset schema are present - assert len(df.columns) == len(self.dataset_def.dataset_schema) + assert len(df.columns) == len(self.dataset_def.schema) # TODO: type checking against column types in dataset schema return df @@ -89,17 +89,14 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): async def register_dataset( self, - dataset_def: DatasetDef, + dataset: Dataset, ) -> None: - dataset_impl = PandasDataframeDataset(dataset_def) - self.dataset_infos[dataset_def.identifier] = DatasetInfo( - dataset_def=dataset_def, + dataset_impl = PandasDataframeDataset(dataset) + self.dataset_infos[dataset.identifier] = DatasetInfo( + dataset_def=dataset, dataset_impl=dataset_impl, ) - async def list_datasets(self) -> List[DatasetDef]: - return [i.dataset_def for i in self.dataset_infos.values()] - async def get_rows_paginated( self, dataset_id: str, diff --git a/llama_stack/providers/tests/datasetio/test_datasetio.py b/llama_stack/providers/tests/datasetio/test_datasetio.py index c02794c50..2b2d57ddd 100644 --- a/llama_stack/providers/tests/datasetio/test_datasetio.py +++ b/llama_stack/providers/tests/datasetio/test_datasetio.py @@ -55,15 +55,11 @@ async def register_dataset( "generated_answer": StringType(), } - dataset = DatasetDefWithProvider( - identifier=dataset_id, - provider_id="", - url=URL( - uri=test_url, - ), - dataset_schema=dataset_schema, + await datasets_impl.register_dataset( + dataset_id=dataset_id, + schema=dataset_schema, + url=URL(uri=test_url), ) - await datasets_impl.register_dataset(dataset) class TestDatasetIO: