forked from phoenix-oss/llama-stack-mirror
feat(dataset api): (1.4/n) fix resolver signature mismatch (#1658)
# What does this PR do?

- Fix the datasets API signature mismatch so that `llama stack run` can start.

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan

```
llama stack run
```

<img width="626" alt="image" src="https://github.com/user-attachments/assets/59072d1a-ccb6-453a-80e8-d87419896c41" />

[//]: # (## Documentation)
This commit is contained in:
parent
72ccdc19a8
commit
2c9d624910
10 changed files with 105 additions and 80 deletions
|
@ -12,7 +12,8 @@ from llama_stack.apis.common.content_types import (
|
||||||
InterleavedContent,
|
InterleavedContent,
|
||||||
InterleavedContentItem,
|
InterleavedContentItem,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
|
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
|
||||||
|
from llama_stack.apis.datasets import DatasetPurpose, DataSource
|
||||||
from llama_stack.apis.eval import (
|
from llama_stack.apis.eval import (
|
||||||
BenchmarkConfig,
|
BenchmarkConfig,
|
||||||
Eval,
|
Eval,
|
||||||
|
@ -160,7 +161,11 @@ class InferenceRouter(Inference):
|
||||||
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
|
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
|
||||||
|
|
||||||
def _construct_metrics(
|
def _construct_metrics(
|
||||||
self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model
|
self,
|
||||||
|
prompt_tokens: int,
|
||||||
|
completion_tokens: int,
|
||||||
|
total_tokens: int,
|
||||||
|
model: Model,
|
||||||
) -> List[MetricEvent]:
|
) -> List[MetricEvent]:
|
||||||
"""Constructs a list of MetricEvent objects containing token usage metrics.
|
"""Constructs a list of MetricEvent objects containing token usage metrics.
|
||||||
|
|
||||||
|
@ -298,7 +303,12 @@ class InferenceRouter(Inference):
|
||||||
completion_text += chunk.event.delta.text
|
completion_text += chunk.event.delta.text
|
||||||
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
|
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
|
||||||
completion_tokens = await self._count_tokens(
|
completion_tokens = await self._count_tokens(
|
||||||
[CompletionMessage(content=completion_text, stop_reason=StopReason.end_of_turn)],
|
[
|
||||||
|
CompletionMessage(
|
||||||
|
content=completion_text,
|
||||||
|
stop_reason=StopReason.end_of_turn,
|
||||||
|
)
|
||||||
|
],
|
||||||
tool_config.tool_prompt_format,
|
tool_config.tool_prompt_format,
|
||||||
)
|
)
|
||||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||||
|
@ -471,21 +481,36 @@ class DatasetIORouter(DatasetIO):
|
||||||
logger.debug("DatasetIORouter.shutdown")
|
logger.debug("DatasetIORouter.shutdown")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def get_rows_paginated(
|
async def register_dataset(
|
||||||
|
self,
|
||||||
|
purpose: DatasetPurpose,
|
||||||
|
source: DataSource,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
dataset_id: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
logger.debug(
|
||||||
|
f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
|
||||||
|
)
|
||||||
|
await self.routing_table.register_dataset(
|
||||||
|
purpose=purpose,
|
||||||
|
source=source,
|
||||||
|
metadata=metadata,
|
||||||
|
dataset_id=dataset_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def iterrows(
|
||||||
self,
|
self,
|
||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
rows_in_page: int,
|
start_index: Optional[int] = None,
|
||||||
page_token: Optional[str] = None,
|
limit: Optional[int] = None,
|
||||||
filter_condition: Optional[str] = None,
|
) -> IterrowsResponse:
|
||||||
) -> PaginatedRowsResult:
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}",
|
f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
|
||||||
)
|
)
|
||||||
return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
|
return await self.routing_table.get_provider_impl(dataset_id).iterrows(
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
rows_in_page=rows_in_page,
|
start_index=start_index,
|
||||||
page_token=page_token,
|
limit=limit,
|
||||||
filter_condition=filter_condition,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import uuid
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from pydantic import TypeAdapter
|
from pydantic import TypeAdapter
|
||||||
|
@ -12,7 +13,14 @@ from pydantic import TypeAdapter
|
||||||
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
|
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.common.type_system import ParamType
|
from llama_stack.apis.common.type_system import ParamType
|
||||||
from llama_stack.apis.datasets import Dataset, Datasets, ListDatasetsResponse
|
from llama_stack.apis.datasets import (
|
||||||
|
Dataset,
|
||||||
|
DatasetPurpose,
|
||||||
|
Datasets,
|
||||||
|
DatasetType,
|
||||||
|
DataSource,
|
||||||
|
ListDatasetsResponse,
|
||||||
|
)
|
||||||
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
|
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
|
||||||
from llama_stack.apis.resource import ResourceType
|
from llama_stack.apis.resource import ResourceType
|
||||||
from llama_stack.apis.scoring_functions import (
|
from llama_stack.apis.scoring_functions import (
|
||||||
|
@ -352,34 +360,42 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||||
|
|
||||||
async def register_dataset(
|
async def register_dataset(
|
||||||
self,
|
self,
|
||||||
dataset_id: str,
|
purpose: DatasetPurpose,
|
||||||
dataset_schema: Dict[str, ParamType],
|
source: DataSource,
|
||||||
url: URL,
|
|
||||||
provider_dataset_id: Optional[str] = None,
|
|
||||||
provider_id: Optional[str] = None,
|
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
) -> None:
|
dataset_id: Optional[str] = None,
|
||||||
if provider_dataset_id is None:
|
) -> Dataset:
|
||||||
provider_dataset_id = dataset_id
|
if not dataset_id:
|
||||||
if provider_id is None:
|
dataset_id = f"dataset-{str(uuid.uuid4())}"
|
||||||
# If provider_id not specified, use the only provider if it supports this dataset
|
|
||||||
if len(self.impls_by_provider_id) == 1:
|
provider_dataset_id = dataset_id
|
||||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
|
||||||
|
# infer provider from source
|
||||||
|
if source.type == DatasetType.rows:
|
||||||
|
provider_id = "localfs"
|
||||||
|
elif source.type == DatasetType.uri:
|
||||||
|
# infer provider from uri
|
||||||
|
if source.uri.startswith("huggingface"):
|
||||||
|
provider_id = "huggingface"
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
provider_id = "localfs"
|
||||||
f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
|
else:
|
||||||
)
|
raise ValueError(f"Unknown data source type: {source.type}")
|
||||||
|
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
metadata = {}
|
metadata = {}
|
||||||
|
|
||||||
dataset = Dataset(
|
dataset = Dataset(
|
||||||
identifier=dataset_id,
|
identifier=dataset_id,
|
||||||
provider_resource_id=provider_dataset_id,
|
provider_resource_id=provider_dataset_id,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
dataset_schema=dataset_schema,
|
purpose=purpose,
|
||||||
url=url,
|
source=source,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.register_object(dataset)
|
await self.register_object(dataset)
|
||||||
|
return dataset
|
||||||
|
|
||||||
async def unregister_dataset(self, dataset_id: str) -> None:
|
async def unregister_dataset(self, dataset_id: str) -> None:
|
||||||
dataset = await self.get_dataset(dataset_id)
|
dataset = await self.get_dataset(dataset_id)
|
||||||
|
|
|
@ -166,7 +166,7 @@ def run_evaluation_3():
|
||||||
eval_candidate = st.session_state["eval_candidate"]
|
eval_candidate = st.session_state["eval_candidate"]
|
||||||
|
|
||||||
dataset_id = benchmarks[selected_benchmark].dataset_id
|
dataset_id = benchmarks[selected_benchmark].dataset_id
|
||||||
rows = llama_stack_api.client.datasetio.get_rows_paginated(
|
rows = llama_stack_api.client.datasetio.iterrows(
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
rows_in_page=-1,
|
rows_in_page=-1,
|
||||||
)
|
)
|
||||||
|
|
|
@ -13,7 +13,7 @@ from urllib.parse import urlparse
|
||||||
import pandas
|
import pandas
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
|
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
|
||||||
from llama_stack.apis.datasets import Dataset
|
from llama_stack.apis.datasets import Dataset
|
||||||
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
||||||
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
|
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
|
||||||
|
@ -128,36 +128,27 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
||||||
await self.kvstore.delete(key=key)
|
await self.kvstore.delete(key=key)
|
||||||
del self.dataset_infos[dataset_id]
|
del self.dataset_infos[dataset_id]
|
||||||
|
|
||||||
async def get_rows_paginated(
|
async def iterrows(
|
||||||
self,
|
self,
|
||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
rows_in_page: int,
|
start_index: Optional[int] = None,
|
||||||
page_token: Optional[str] = None,
|
limit: Optional[int] = None,
|
||||||
filter_condition: Optional[str] = None,
|
) -> IterrowsResponse:
|
||||||
) -> PaginatedRowsResult:
|
|
||||||
dataset_info = self.dataset_infos.get(dataset_id)
|
dataset_info = self.dataset_infos.get(dataset_id)
|
||||||
dataset_info.dataset_impl.load()
|
dataset_info.dataset_impl.load()
|
||||||
|
|
||||||
if page_token and not page_token.isnumeric():
|
start_index = start_index or 0
|
||||||
raise ValueError("Invalid page_token")
|
|
||||||
|
|
||||||
if page_token is None or len(page_token) == 0:
|
if limit is None or limit == -1:
|
||||||
next_page_token = 0
|
|
||||||
else:
|
|
||||||
next_page_token = int(page_token)
|
|
||||||
|
|
||||||
start = next_page_token
|
|
||||||
if rows_in_page == -1:
|
|
||||||
end = len(dataset_info.dataset_impl)
|
end = len(dataset_info.dataset_impl)
|
||||||
else:
|
else:
|
||||||
end = min(start + rows_in_page, len(dataset_info.dataset_impl))
|
end = min(start_index + limit, len(dataset_info.dataset_impl))
|
||||||
|
|
||||||
rows = dataset_info.dataset_impl[start:end]
|
rows = dataset_info.dataset_impl[start_index:end]
|
||||||
|
|
||||||
return PaginatedRowsResult(
|
return IterrowsResponse(
|
||||||
rows=rows,
|
data=rows,
|
||||||
total_count=len(rows),
|
next_index=end if end < len(dataset_info.dataset_impl) else None,
|
||||||
next_page_token=str(end),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||||
|
|
|
@ -90,7 +90,7 @@ class MetaReferenceEvalImpl(
|
||||||
scoring_functions = task_def.scoring_functions
|
scoring_functions = task_def.scoring_functions
|
||||||
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||||
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))
|
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))
|
||||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
all_rows = await self.datasetio_api.iterrows(
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
rows_in_page=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
|
rows_in_page=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
|
||||||
)
|
)
|
||||||
|
|
|
@ -328,7 +328,7 @@ class LoraFinetuningSingleDevice:
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
) -> Tuple[DistributedSampler, DataLoader]:
|
) -> Tuple[DistributedSampler, DataLoader]:
|
||||||
async def fetch_rows(dataset_id: str):
|
async def fetch_rows(dataset_id: str):
|
||||||
return await self.datasetio_api.get_rows_paginated(
|
return await self.datasetio_api.iterrows(
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
rows_in_page=-1,
|
rows_in_page=-1,
|
||||||
)
|
)
|
||||||
|
|
|
@ -24,7 +24,9 @@ from llama_stack.providers.utils.common.data_schema_validator import (
|
||||||
from .config import BasicScoringConfig
|
from .config import BasicScoringConfig
|
||||||
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
|
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
|
||||||
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
|
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
|
||||||
from .scoring_fn.regex_parser_math_response_scoring_fn import RegexParserMathResponseScoringFn
|
from .scoring_fn.regex_parser_math_response_scoring_fn import (
|
||||||
|
RegexParserMathResponseScoringFn,
|
||||||
|
)
|
||||||
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
|
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
|
||||||
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
|
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
|
||||||
|
|
||||||
|
@ -82,7 +84,7 @@ class BasicScoringImpl(
|
||||||
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||||
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
|
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
|
||||||
|
|
||||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
all_rows = await self.datasetio_api.iterrows(
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
rows_in_page=-1,
|
rows_in_page=-1,
|
||||||
)
|
)
|
||||||
|
|
|
@ -167,7 +167,7 @@ class BraintrustScoringImpl(
|
||||||
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||||
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
|
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
|
||||||
|
|
||||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
all_rows = await self.datasetio_api.iterrows(
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
rows_in_page=-1,
|
rows_in_page=-1,
|
||||||
)
|
)
|
||||||
|
|
|
@ -72,7 +72,7 @@ class LlmAsJudgeScoringImpl(
|
||||||
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||||
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
|
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
|
||||||
|
|
||||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
all_rows = await self.datasetio_api.iterrows(
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
rows_in_page=-1,
|
rows_in_page=-1,
|
||||||
)
|
)
|
||||||
|
|
|
@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import datasets as hf_datasets
|
import datasets as hf_datasets
|
||||||
|
|
||||||
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
|
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
|
||||||
from llama_stack.apis.datasets import Dataset
|
from llama_stack.apis.datasets import Dataset
|
||||||
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
||||||
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
|
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
|
||||||
|
@ -73,36 +73,27 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
||||||
await self.kvstore.delete(key=key)
|
await self.kvstore.delete(key=key)
|
||||||
del self.dataset_infos[dataset_id]
|
del self.dataset_infos[dataset_id]
|
||||||
|
|
||||||
async def get_rows_paginated(
|
async def iterrows(
|
||||||
self,
|
self,
|
||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
rows_in_page: int,
|
start_index: Optional[int] = None,
|
||||||
page_token: Optional[str] = None,
|
limit: Optional[int] = None,
|
||||||
filter_condition: Optional[str] = None,
|
) -> IterrowsResponse:
|
||||||
) -> PaginatedRowsResult:
|
|
||||||
dataset_def = self.dataset_infos[dataset_id]
|
dataset_def = self.dataset_infos[dataset_id]
|
||||||
loaded_dataset = load_hf_dataset(dataset_def)
|
loaded_dataset = load_hf_dataset(dataset_def)
|
||||||
|
|
||||||
if page_token and not page_token.isnumeric():
|
start_index = start_index or 0
|
||||||
raise ValueError("Invalid page_token")
|
|
||||||
|
|
||||||
if page_token is None or len(page_token) == 0:
|
if limit is None or limit == -1:
|
||||||
next_page_token = 0
|
|
||||||
else:
|
|
||||||
next_page_token = int(page_token)
|
|
||||||
|
|
||||||
start = next_page_token
|
|
||||||
if rows_in_page == -1:
|
|
||||||
end = len(loaded_dataset)
|
end = len(loaded_dataset)
|
||||||
else:
|
else:
|
||||||
end = min(start + rows_in_page, len(loaded_dataset))
|
end = min(start_index + limit, len(loaded_dataset))
|
||||||
|
|
||||||
rows = [loaded_dataset[i] for i in range(start, end)]
|
rows = [loaded_dataset[i] for i in range(start_index, end)]
|
||||||
|
|
||||||
return PaginatedRowsResult(
|
return IterrowsResponse(
|
||||||
rows=rows,
|
data=rows,
|
||||||
total_count=len(rows),
|
next_index=end if end < len(loaded_dataset) else None,
|
||||||
next_page_token=str(end),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue