feat(dataset api): (1.4/n) fix resolver signature mismatch (#1658)

# What does this PR do?
- fix datasets api signature mis-match so that llama stack run can start

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan
```
llama stack run
```
<img width="626" alt="image"
src="https://github.com/user-attachments/assets/59072d1a-ccb6-453a-80e8-d87419896c41"
/>


[//]: # (## Documentation)
This commit is contained in:
Xi Yan 2025-03-15 14:56:11 -07:00 committed by GitHub
parent 72ccdc19a8
commit 2c9d624910
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 105 additions and 80 deletions

View file

@ -12,7 +12,8 @@ from llama_stack.apis.common.content_types import (
InterleavedContent, InterleavedContent,
InterleavedContentItem, InterleavedContentItem,
) )
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.apis.eval import ( from llama_stack.apis.eval import (
BenchmarkConfig, BenchmarkConfig,
Eval, Eval,
@ -160,7 +161,11 @@ class InferenceRouter(Inference):
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type) await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
def _construct_metrics( def _construct_metrics(
self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model self,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
model: Model,
) -> List[MetricEvent]: ) -> List[MetricEvent]:
"""Constructs a list of MetricEvent objects containing token usage metrics. """Constructs a list of MetricEvent objects containing token usage metrics.
@ -298,7 +303,12 @@ class InferenceRouter(Inference):
completion_text += chunk.event.delta.text completion_text += chunk.event.delta.text
if chunk.event.event_type == ChatCompletionResponseEventType.complete: if chunk.event.event_type == ChatCompletionResponseEventType.complete:
completion_tokens = await self._count_tokens( completion_tokens = await self._count_tokens(
[CompletionMessage(content=completion_text, stop_reason=StopReason.end_of_turn)], [
CompletionMessage(
content=completion_text,
stop_reason=StopReason.end_of_turn,
)
],
tool_config.tool_prompt_format, tool_config.tool_prompt_format,
) )
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
@ -471,21 +481,36 @@ class DatasetIORouter(DatasetIO):
logger.debug("DatasetIORouter.shutdown") logger.debug("DatasetIORouter.shutdown")
pass pass
async def get_rows_paginated( async def register_dataset(
self,
purpose: DatasetPurpose,
source: DataSource,
metadata: Optional[Dict[str, Any]] = None,
dataset_id: Optional[str] = None,
) -> None:
logger.debug(
f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
)
await self.routing_table.register_dataset(
purpose=purpose,
source=source,
metadata=metadata,
dataset_id=dataset_id,
)
async def iterrows(
self, self,
dataset_id: str, dataset_id: str,
rows_in_page: int, start_index: Optional[int] = None,
page_token: Optional[str] = None, limit: Optional[int] = None,
filter_condition: Optional[str] = None, ) -> IterrowsResponse:
) -> PaginatedRowsResult:
logger.debug( logger.debug(
f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}", f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
) )
return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated( return await self.routing_table.get_provider_impl(dataset_id).iterrows(
dataset_id=dataset_id, dataset_id=dataset_id,
rows_in_page=rows_in_page, start_index=start_index,
page_token=page_token, limit=limit,
filter_condition=filter_condition,
) )
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:

View file

@ -5,6 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import logging import logging
import uuid
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from pydantic import TypeAdapter from pydantic import TypeAdapter
@ -12,7 +13,14 @@ from pydantic import TypeAdapter
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.datasets import Dataset, Datasets, ListDatasetsResponse from llama_stack.apis.datasets import (
Dataset,
DatasetPurpose,
Datasets,
DatasetType,
DataSource,
ListDatasetsResponse,
)
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
from llama_stack.apis.resource import ResourceType from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import ( from llama_stack.apis.scoring_functions import (
@ -352,34 +360,42 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def register_dataset( async def register_dataset(
self, self,
dataset_id: str, purpose: DatasetPurpose,
dataset_schema: Dict[str, ParamType], source: DataSource,
url: URL,
provider_dataset_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
) -> None: dataset_id: Optional[str] = None,
if provider_dataset_id is None: ) -> Dataset:
provider_dataset_id = dataset_id if not dataset_id:
if provider_id is None: dataset_id = f"dataset-{str(uuid.uuid4())}"
# If provider_id not specified, use the only provider if it supports this dataset
if len(self.impls_by_provider_id) == 1: provider_dataset_id = dataset_id
provider_id = list(self.impls_by_provider_id.keys())[0]
# infer provider from source
if source.type == DatasetType.rows:
provider_id = "localfs"
elif source.type == DatasetType.uri:
# infer provider from uri
if source.uri.startswith("huggingface"):
provider_id = "huggingface"
else: else:
raise ValueError( provider_id = "localfs"
f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}" else:
) raise ValueError(f"Unknown data source type: {source.type}")
if metadata is None: if metadata is None:
metadata = {} metadata = {}
dataset = Dataset( dataset = Dataset(
identifier=dataset_id, identifier=dataset_id,
provider_resource_id=provider_dataset_id, provider_resource_id=provider_dataset_id,
provider_id=provider_id, provider_id=provider_id,
dataset_schema=dataset_schema, purpose=purpose,
url=url, source=source,
metadata=metadata, metadata=metadata,
) )
await self.register_object(dataset) await self.register_object(dataset)
return dataset
async def unregister_dataset(self, dataset_id: str) -> None: async def unregister_dataset(self, dataset_id: str) -> None:
dataset = await self.get_dataset(dataset_id) dataset = await self.get_dataset(dataset_id)

View file

@ -166,7 +166,7 @@ def run_evaluation_3():
eval_candidate = st.session_state["eval_candidate"] eval_candidate = st.session_state["eval_candidate"]
dataset_id = benchmarks[selected_benchmark].dataset_id dataset_id = benchmarks[selected_benchmark].dataset_id
rows = llama_stack_api.client.datasetio.get_rows_paginated( rows = llama_stack_api.client.datasetio.iterrows(
dataset_id=dataset_id, dataset_id=dataset_id,
rows_in_page=-1, rows_in_page=-1,
) )

View file

@ -13,7 +13,7 @@ from urllib.parse import urlparse
import pandas import pandas
from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.content_types import URL
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.datasets import Dataset from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
@ -128,36 +128,27 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
await self.kvstore.delete(key=key) await self.kvstore.delete(key=key)
del self.dataset_infos[dataset_id] del self.dataset_infos[dataset_id]
async def get_rows_paginated( async def iterrows(
self, self,
dataset_id: str, dataset_id: str,
rows_in_page: int, start_index: Optional[int] = None,
page_token: Optional[str] = None, limit: Optional[int] = None,
filter_condition: Optional[str] = None, ) -> IterrowsResponse:
) -> PaginatedRowsResult:
dataset_info = self.dataset_infos.get(dataset_id) dataset_info = self.dataset_infos.get(dataset_id)
dataset_info.dataset_impl.load() dataset_info.dataset_impl.load()
if page_token and not page_token.isnumeric(): start_index = start_index or 0
raise ValueError("Invalid page_token")
if page_token is None or len(page_token) == 0: if limit is None or limit == -1:
next_page_token = 0
else:
next_page_token = int(page_token)
start = next_page_token
if rows_in_page == -1:
end = len(dataset_info.dataset_impl) end = len(dataset_info.dataset_impl)
else: else:
end = min(start + rows_in_page, len(dataset_info.dataset_impl)) end = min(start_index + limit, len(dataset_info.dataset_impl))
rows = dataset_info.dataset_impl[start:end] rows = dataset_info.dataset_impl[start_index:end]
return PaginatedRowsResult( return IterrowsResponse(
rows=rows, data=rows,
total_count=len(rows), next_index=end if end < len(dataset_info.dataset_impl) else None,
next_page_token=str(end),
) )
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:

View file

@ -90,7 +90,7 @@ class MetaReferenceEvalImpl(
scoring_functions = task_def.scoring_functions scoring_functions = task_def.scoring_functions
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value)) validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))
all_rows = await self.datasetio_api.get_rows_paginated( all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id, dataset_id=dataset_id,
rows_in_page=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples), rows_in_page=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
) )

View file

@ -328,7 +328,7 @@ class LoraFinetuningSingleDevice:
batch_size: int, batch_size: int,
) -> Tuple[DistributedSampler, DataLoader]: ) -> Tuple[DistributedSampler, DataLoader]:
async def fetch_rows(dataset_id: str): async def fetch_rows(dataset_id: str):
return await self.datasetio_api.get_rows_paginated( return await self.datasetio_api.iterrows(
dataset_id=dataset_id, dataset_id=dataset_id,
rows_in_page=-1, rows_in_page=-1,
) )

View file

@ -24,7 +24,9 @@ from llama_stack.providers.utils.common.data_schema_validator import (
from .config import BasicScoringConfig from .config import BasicScoringConfig
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
from .scoring_fn.equality_scoring_fn import EqualityScoringFn from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.regex_parser_math_response_scoring_fn import RegexParserMathResponseScoringFn from .scoring_fn.regex_parser_math_response_scoring_fn import (
RegexParserMathResponseScoringFn,
)
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
@ -82,7 +84,7 @@ class BasicScoringImpl(
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)) validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
all_rows = await self.datasetio_api.get_rows_paginated( all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id, dataset_id=dataset_id,
rows_in_page=-1, rows_in_page=-1,
) )

View file

@ -167,7 +167,7 @@ class BraintrustScoringImpl(
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)) validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
all_rows = await self.datasetio_api.get_rows_paginated( all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id, dataset_id=dataset_id,
rows_in_page=-1, rows_in_page=-1,
) )

View file

@ -72,7 +72,7 @@ class LlmAsJudgeScoringImpl(
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)) validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
all_rows = await self.datasetio_api.get_rows_paginated( all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id, dataset_id=dataset_id,
rows_in_page=-1, rows_in_page=-1,
) )

View file

@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional
import datasets as hf_datasets import datasets as hf_datasets
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.datasets import Dataset from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
@ -73,36 +73,27 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
await self.kvstore.delete(key=key) await self.kvstore.delete(key=key)
del self.dataset_infos[dataset_id] del self.dataset_infos[dataset_id]
async def get_rows_paginated( async def iterrows(
self, self,
dataset_id: str, dataset_id: str,
rows_in_page: int, start_index: Optional[int] = None,
page_token: Optional[str] = None, limit: Optional[int] = None,
filter_condition: Optional[str] = None, ) -> IterrowsResponse:
) -> PaginatedRowsResult:
dataset_def = self.dataset_infos[dataset_id] dataset_def = self.dataset_infos[dataset_id]
loaded_dataset = load_hf_dataset(dataset_def) loaded_dataset = load_hf_dataset(dataset_def)
if page_token and not page_token.isnumeric(): start_index = start_index or 0
raise ValueError("Invalid page_token")
if page_token is None or len(page_token) == 0: if limit is None or limit == -1:
next_page_token = 0
else:
next_page_token = int(page_token)
start = next_page_token
if rows_in_page == -1:
end = len(loaded_dataset) end = len(loaded_dataset)
else: else:
end = min(start + rows_in_page, len(loaded_dataset)) end = min(start_index + limit, len(loaded_dataset))
rows = [loaded_dataset[i] for i in range(start, end)] rows = [loaded_dataset[i] for i in range(start_index, end)]
return PaginatedRowsResult( return IterrowsResponse(
rows=rows, data=rows,
total_count=len(rows), next_index=end if end < len(loaded_dataset) else None,
next_page_token=str(end),
) )
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: