diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 22a1e46f9..2cf38f544 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -12,7 +12,8 @@ from llama_stack.apis.common.content_types import ( InterleavedContent, InterleavedContentItem, ) -from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult +from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse +from llama_stack.apis.datasets import DatasetPurpose, DataSource from llama_stack.apis.eval import ( BenchmarkConfig, Eval, @@ -160,7 +161,11 @@ class InferenceRouter(Inference): await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type) def _construct_metrics( - self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model + self, + prompt_tokens: int, + completion_tokens: int, + total_tokens: int, + model: Model, ) -> List[MetricEvent]: """Constructs a list of MetricEvent objects containing token usage metrics. 
@@ -298,7 +303,12 @@ class InferenceRouter(Inference): completion_text += chunk.event.delta.text if chunk.event.event_type == ChatCompletionResponseEventType.complete: completion_tokens = await self._count_tokens( - [CompletionMessage(content=completion_text, stop_reason=StopReason.end_of_turn)], + [ + CompletionMessage( + content=completion_text, + stop_reason=StopReason.end_of_turn, + ) + ], tool_config.tool_prompt_format, ) total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) @@ -471,21 +481,36 @@ class DatasetIORouter(DatasetIO): logger.debug("DatasetIORouter.shutdown") pass - async def get_rows_paginated( + async def register_dataset( + self, + purpose: DatasetPurpose, + source: DataSource, + metadata: Optional[Dict[str, Any]] = None, + dataset_id: Optional[str] = None, + ) -> None: + logger.debug( + f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}", + ) + await self.routing_table.register_dataset( + purpose=purpose, + source=source, + metadata=metadata, + dataset_id=dataset_id, + ) + + async def iterrows( self, dataset_id: str, - rows_in_page: int, - page_token: Optional[str] = None, - filter_condition: Optional[str] = None, - ) -> PaginatedRowsResult: + start_index: Optional[int] = None, + limit: Optional[int] = None, + ) -> IterrowsResponse: logger.debug( - f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}", + f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}", ) - return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated( + return await self.routing_table.get_provider_impl(dataset_id).iterrows( dataset_id=dataset_id, - rows_in_page=rows_in_page, - page_token=page_token, - filter_condition=filter_condition, + start_index=start_index, + limit=limit, ) async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py 
index 1be43ec8b..589a03b25 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -5,6 +5,7 @@ # the root directory of this source tree. import logging +import uuid from typing import Any, Dict, List, Optional from pydantic import TypeAdapter @@ -12,7 +13,14 @@ from pydantic import TypeAdapter from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.type_system import ParamType -from llama_stack.apis.datasets import Dataset, Datasets, ListDatasetsResponse +from llama_stack.apis.datasets import ( + Dataset, + DatasetPurpose, + Datasets, + DatasetType, + DataSource, + ListDatasetsResponse, +) from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType from llama_stack.apis.resource import ResourceType from llama_stack.apis.scoring_functions import ( @@ -352,34 +360,42 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): async def register_dataset( self, - dataset_id: str, - dataset_schema: Dict[str, ParamType], - url: URL, - provider_dataset_id: Optional[str] = None, - provider_id: Optional[str] = None, + purpose: DatasetPurpose, + source: DataSource, metadata: Optional[Dict[str, Any]] = None, - ) -> None: - if provider_dataset_id is None: - provider_dataset_id = dataset_id - if provider_id is None: - # If provider_id not specified, use the only provider if it supports this dataset - if len(self.impls_by_provider_id) == 1: - provider_id = list(self.impls_by_provider_id.keys())[0] + dataset_id: Optional[str] = None, + ) -> Dataset: + if not dataset_id: + dataset_id = f"dataset-{str(uuid.uuid4())}" + + provider_dataset_id = dataset_id + + # infer provider from source + if source.type == DatasetType.rows: + provider_id = "localfs" + elif source.type == DatasetType.uri: + # infer provider from uri + if source.uri.startswith("huggingface"): + provider_id = 
"huggingface" else: - raise ValueError( - f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}" - ) + provider_id = "localfs" + else: + raise ValueError(f"Unknown data source type: {source.type}") + if metadata is None: metadata = {} + dataset = Dataset( identifier=dataset_id, provider_resource_id=provider_dataset_id, provider_id=provider_id, - dataset_schema=dataset_schema, - url=url, + purpose=purpose, + source=source, metadata=metadata, ) + await self.register_object(dataset) + return dataset async def unregister_dataset(self, dataset_id: str) -> None: dataset = await self.get_dataset(dataset_id) diff --git a/llama_stack/distribution/ui/page/evaluations/native_eval.py b/llama_stack/distribution/ui/page/evaluations/native_eval.py index 00e949ed6..5ce5bc5d2 100644 --- a/llama_stack/distribution/ui/page/evaluations/native_eval.py +++ b/llama_stack/distribution/ui/page/evaluations/native_eval.py @@ -166,7 +166,7 @@ def run_evaluation_3(): eval_candidate = st.session_state["eval_candidate"] dataset_id = benchmarks[selected_benchmark].dataset_id - rows = llama_stack_api.client.datasetio.get_rows_paginated( + rows = llama_stack_api.client.datasetio.iterrows( dataset_id=dataset_id, - rows_in_page=-1, + limit=-1, ) diff --git a/llama_stack/providers/inline/datasetio/localfs/datasetio.py b/llama_stack/providers/inline/datasetio/localfs/datasetio.py index c5216e026..afa9ee0ff 100644 --- a/llama_stack/providers/inline/datasetio/localfs/datasetio.py +++ b/llama_stack/providers/inline/datasetio/localfs/datasetio.py @@ -13,7 +13,7 @@ from urllib.parse import urlparse import pandas from llama_stack.apis.common.content_types import URL -from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult +from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse from llama_stack.apis.datasets import Dataset from llama_stack.providers.datatypes import DatasetsProtocolPrivate from 
llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url @@ -128,36 +128,27 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): await self.kvstore.delete(key=key) del self.dataset_infos[dataset_id] - async def get_rows_paginated( + async def iterrows( self, dataset_id: str, - rows_in_page: int, - page_token: Optional[str] = None, - filter_condition: Optional[str] = None, - ) -> PaginatedRowsResult: + start_index: Optional[int] = None, + limit: Optional[int] = None, + ) -> IterrowsResponse: dataset_info = self.dataset_infos.get(dataset_id) dataset_info.dataset_impl.load() - if page_token and not page_token.isnumeric(): - raise ValueError("Invalid page_token") + start_index = start_index or 0 - if page_token is None or len(page_token) == 0: - next_page_token = 0 - else: - next_page_token = int(page_token) - - start = next_page_token - if rows_in_page == -1: + if limit is None or limit == -1: end = len(dataset_info.dataset_impl) else: - end = min(start + rows_in_page, len(dataset_info.dataset_impl)) + end = min(start_index + limit, len(dataset_info.dataset_impl)) - rows = dataset_info.dataset_impl[start:end] + rows = dataset_info.dataset_impl[start_index:end] - return PaginatedRowsResult( - rows=rows, - total_count=len(rows), - next_page_token=str(end), + return IterrowsResponse( + data=rows, + next_index=end if end < len(dataset_info.dataset_impl) else None, ) async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py index 85b351262..67e2eb193 100644 --- a/llama_stack/providers/inline/eval/meta_reference/eval.py +++ b/llama_stack/providers/inline/eval/meta_reference/eval.py @@ -90,7 +90,7 @@ class MetaReferenceEvalImpl( scoring_functions = task_def.scoring_functions dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) 
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value)) - all_rows = await self.datasetio_api.get_rows_paginated( + all_rows = await self.datasetio_api.iterrows( dataset_id=dataset_id, - rows_in_page=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples), + limit=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples), ) diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py index 941c629e3..482bbd309 100644 --- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py @@ -328,7 +328,7 @@ class LoraFinetuningSingleDevice: batch_size: int, ) -> Tuple[DistributedSampler, DataLoader]: async def fetch_rows(dataset_id: str): - return await self.datasetio_api.get_rows_paginated( + return await self.datasetio_api.iterrows( dataset_id=dataset_id, - rows_in_page=-1, + limit=-1, ) diff --git a/llama_stack/providers/inline/scoring/basic/scoring.py b/llama_stack/providers/inline/scoring/basic/scoring.py index 599f5f98c..915c33c8d 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring.py +++ b/llama_stack/providers/inline/scoring/basic/scoring.py @@ -24,7 +24,9 @@ from llama_stack.providers.utils.common.data_schema_validator import ( from .config import BasicScoringConfig from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn from .scoring_fn.equality_scoring_fn import EqualityScoringFn -from .scoring_fn.regex_parser_math_response_scoring_fn import RegexParserMathResponseScoringFn +from .scoring_fn.regex_parser_math_response_scoring_fn import ( + RegexParserMathResponseScoringFn, +) from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn @@ -82,7 +84,7 @@ class BasicScoringImpl( dataset_def = await 
self.datasets_api.get_dataset(dataset_id=dataset_id) validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)) - all_rows = await self.datasetio_api.get_rows_paginated( + all_rows = await self.datasetio_api.iterrows( dataset_id=dataset_id, - rows_in_page=-1, + limit=-1, ) diff --git a/llama_stack/providers/inline/scoring/braintrust/braintrust.py b/llama_stack/providers/inline/scoring/braintrust/braintrust.py index a48b6b58b..1f5c3e147 100644 --- a/llama_stack/providers/inline/scoring/braintrust/braintrust.py +++ b/llama_stack/providers/inline/scoring/braintrust/braintrust.py @@ -167,7 +167,7 @@ class BraintrustScoringImpl( dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)) - all_rows = await self.datasetio_api.get_rows_paginated( + all_rows = await self.datasetio_api.iterrows( dataset_id=dataset_id, - rows_in_page=-1, + limit=-1, ) diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py index 5b1715d9f..c6e0d39c9 100644 --- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py @@ -72,7 +72,7 @@ class LlmAsJudgeScoringImpl( dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)) - all_rows = await self.datasetio_api.get_rows_paginated( + all_rows = await self.datasetio_api.iterrows( dataset_id=dataset_id, - rows_in_page=-1, + limit=-1, ) diff --git a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py index cd4e7f1f1..d59edda30 100644 --- a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py +++ b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py @@ -7,7 +7,7 @@ from typing import Any, Dict, List, 
Optional import datasets as hf_datasets -from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult +from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse from llama_stack.apis.datasets import Dataset from llama_stack.providers.datatypes import DatasetsProtocolPrivate from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url @@ -73,36 +73,27 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): await self.kvstore.delete(key=key) del self.dataset_infos[dataset_id] - async def get_rows_paginated( + async def iterrows( self, dataset_id: str, - rows_in_page: int, - page_token: Optional[str] = None, - filter_condition: Optional[str] = None, - ) -> PaginatedRowsResult: + start_index: Optional[int] = None, + limit: Optional[int] = None, + ) -> IterrowsResponse: dataset_def = self.dataset_infos[dataset_id] loaded_dataset = load_hf_dataset(dataset_def) - if page_token and not page_token.isnumeric(): - raise ValueError("Invalid page_token") + start_index = start_index or 0 - if page_token is None or len(page_token) == 0: - next_page_token = 0 - else: - next_page_token = int(page_token) - - start = next_page_token - if rows_in_page == -1: + if limit is None or limit == -1: end = len(loaded_dataset) else: - end = min(start + rows_in_page, len(loaded_dataset)) + end = min(start_index + limit, len(loaded_dataset)) - rows = [loaded_dataset[i] for i in range(start, end)] + rows = [loaded_dataset[i] for i in range(start_index, end)] - return PaginatedRowsResult( - rows=rows, - total_count=len(rows), - next_page_token=str(end), + return IterrowsResponse( + data=rows, + next_index=end if end < len(loaded_dataset) else None, ) async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: