From ab777ef5cd919c73f77d9a7af8d3c5f03ab57098 Mon Sep 17 00:00:00 2001 From: Botao Chen Date: Wed, 19 Mar 2025 11:27:11 -0700 Subject: [PATCH] fix: fix open-benchmark template (#1695) ## What does this PR do? open-benchmark templated is broken after the datasets api refactor due to 2 reasons - provider_id and provider_resource_id are no longer needed - the type in run.yaml will be resolved as dict this PR is to fix the above 2 issues ## Test spin up a llama stack server successfully with llama stack run `llama_stack/templates/open-benchmark/run.yaml` --- llama_stack/apis/datasets/datasets.py | 2 -- llama_stack/distribution/routers/routing_tables.py | 8 ++++++++ llama_stack/templates/open-benchmark/open_benchmark.py | 5 ----- llama_stack/templates/open-benchmark/run.yaml | 5 ----- 4 files changed, 8 insertions(+), 12 deletions(-) diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index 616371c7d..e2c940f64 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -121,8 +121,6 @@ class Dataset(CommonDatasetFields, Resource): class DatasetInput(CommonDatasetFields, BaseModel): dataset_id: str - provider_id: Optional[str] = None - provider_dataset_id: Optional[str] = None class ListDatasetsResponse(BaseModel): diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 5dea942f7..7aef2f8d5 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -20,6 +20,8 @@ from llama_stack.apis.datasets import ( DatasetType, DataSource, ListDatasetsResponse, + RowsDataSource, + URIDataSource, ) from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType from llama_stack.apis.resource import ResourceType @@ -377,6 +379,12 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): metadata: Optional[Dict[str, Any]] = None, dataset_id: Optional[str] = None, ) -> Dataset: + if isinstance(source, dict): + if source["type"] == "uri": + source = URIDataSource.parse_obj(source) + elif source["type"] == "rows": + source = RowsDataSource.parse_obj(source) + if not dataset_id: dataset_id = f"dataset-{str(uuid.uuid4())}" diff --git a/llama_stack/templates/open-benchmark/open_benchmark.py b/llama_stack/templates/open-benchmark/open_benchmark.py index b339e8c80..acfbd78d6 100644 --- a/llama_stack/templates/open-benchmark/open_benchmark.py +++ b/llama_stack/templates/open-benchmark/open_benchmark.py @@ -170,7 +170,6 @@ def get_distribution_template() -> DistributionTemplate: default_datasets = [ DatasetInput( dataset_id="simpleqa", - provider_id="huggingface", purpose=DatasetPurpose.eval_messages_answer, source=URIDataSource( uri="huggingface://datasets/llamastack/simpleqa?split=train", @@ -178,7 +177,6 @@ def get_distribution_template() -> DistributionTemplate: ), DatasetInput( dataset_id="mmlu_cot", - provider_id="huggingface", purpose=DatasetPurpose.eval_messages_answer, source=URIDataSource( uri="huggingface://datasets/llamastack/mmlu_cot?split=test&name=all", @@ -186,7 +184,6 @@ def get_distribution_template() -> DistributionTemplate: ), DatasetInput( dataset_id="gpqa_cot", - provider_id="huggingface", purpose=DatasetPurpose.eval_messages_answer, source=URIDataSource( uri="huggingface://datasets/llamastack/gpqa_0shot_cot?split=test&name=gpqa_main", @@ -194,7 +191,6 @@ def get_distribution_template() -> DistributionTemplate: ), DatasetInput( dataset_id="math_500", - provider_id="huggingface", purpose=DatasetPurpose.eval_messages_answer, source=URIDataSource( uri="huggingface://datasets/llamastack/math_500?split=test", @@ -202,7 +198,6 @@ def get_distribution_template() -> DistributionTemplate: ), DatasetInput( dataset_id="bfcl", - provider_id="huggingface", purpose=DatasetPurpose.eval_messages_answer, source=URIDataSource( uri="huggingface://datasets/llamastack/bfcl_v3?split=train", diff --git a/llama_stack/templates/open-benchmark/run.yaml b/llama_stack/templates/open-benchmark/run.yaml index 93f437273..8dbf51472 100644 --- a/llama_stack/templates/open-benchmark/run.yaml +++ b/llama_stack/templates/open-benchmark/run.yaml @@ -164,35 +164,30 @@ datasets: uri: huggingface://datasets/llamastack/simpleqa?split=train metadata: {} dataset_id: simpleqa - provider_id: huggingface - purpose: eval/messages-answer source: type: uri uri: huggingface://datasets/llamastack/mmlu_cot?split=test&name=all metadata: {} dataset_id: mmlu_cot - provider_id: huggingface - purpose: eval/messages-answer source: type: uri uri: huggingface://datasets/llamastack/gpqa_0shot_cot?split=test&name=gpqa_main metadata: {} dataset_id: gpqa_cot - provider_id: huggingface - purpose: eval/messages-answer source: type: uri uri: huggingface://datasets/llamastack/math_500?split=test metadata: {} dataset_id: math_500 - provider_id: huggingface - purpose: eval/messages-answer source: type: uri uri: huggingface://datasets/llamastack/bfcl_v3?split=train metadata: {} dataset_id: bfcl - provider_id: huggingface scoring_fns: [] benchmarks: - dataset_id: simpleqa