fix: fix open-benchmark template (#1695)

## What does this PR do?
open-benchmark templated is broken after the datasets api refactor due
to 2 reasons
- provider_id and provider_resource_id are no longer needed 
- the type in run.yaml will be resolved as dict

this PR is to fix the above 2 issues 

## Test 
spin up a llama stack server successfully with llama stack run
`llama_stack/templates/open-benchmark/run.yaml`
This commit is contained in:
Botao Chen 2025-03-19 11:27:11 -07:00 committed by GitHub
parent 6949bd1999
commit ab777ef5cd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 8 additions and 12 deletions

View file

@ -121,8 +121,6 @@ class Dataset(CommonDatasetFields, Resource):
class DatasetInput(CommonDatasetFields, BaseModel):
dataset_id: str
provider_id: Optional[str] = None
provider_dataset_id: Optional[str] = None
class ListDatasetsResponse(BaseModel):

View file

@ -20,6 +20,8 @@ from llama_stack.apis.datasets import (
DatasetType,
DataSource,
ListDatasetsResponse,
RowsDataSource,
URIDataSource,
)
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
from llama_stack.apis.resource import ResourceType
@ -377,6 +379,12 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
metadata: Optional[Dict[str, Any]] = None,
dataset_id: Optional[str] = None,
) -> Dataset:
if isinstance(source, dict):
if source["type"] == "uri":
source = URIDataSource.parse_obj(source)
elif source["type"] == "rows":
source = RowsDataSource.parse_obj(source)
if not dataset_id:
dataset_id = f"dataset-{str(uuid.uuid4())}"

View file

@ -170,7 +170,6 @@ def get_distribution_template() -> DistributionTemplate:
default_datasets = [
DatasetInput(
dataset_id="simpleqa",
provider_id="huggingface",
purpose=DatasetPurpose.eval_messages_answer,
source=URIDataSource(
uri="huggingface://datasets/llamastack/simpleqa?split=train",
@ -178,7 +177,6 @@ def get_distribution_template() -> DistributionTemplate:
),
DatasetInput(
dataset_id="mmlu_cot",
provider_id="huggingface",
purpose=DatasetPurpose.eval_messages_answer,
source=URIDataSource(
uri="huggingface://datasets/llamastack/mmlu_cot?split=test&name=all",
@ -186,7 +184,6 @@ def get_distribution_template() -> DistributionTemplate:
),
DatasetInput(
dataset_id="gpqa_cot",
provider_id="huggingface",
purpose=DatasetPurpose.eval_messages_answer,
source=URIDataSource(
uri="huggingface://datasets/llamastack/gpqa_0shot_cot?split=test&name=gpqa_main",
@ -194,7 +191,6 @@ def get_distribution_template() -> DistributionTemplate:
),
DatasetInput(
dataset_id="math_500",
provider_id="huggingface",
purpose=DatasetPurpose.eval_messages_answer,
source=URIDataSource(
uri="huggingface://datasets/llamastack/math_500?split=test",
@ -202,7 +198,6 @@ def get_distribution_template() -> DistributionTemplate:
),
DatasetInput(
dataset_id="bfcl",
provider_id="huggingface",
purpose=DatasetPurpose.eval_messages_answer,
source=URIDataSource(
uri="huggingface://datasets/llamastack/bfcl_v3?split=train",

View file

@ -164,35 +164,30 @@ datasets:
uri: huggingface://datasets/llamastack/simpleqa?split=train
metadata: {}
dataset_id: simpleqa
provider_id: huggingface
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/mmlu_cot?split=test&name=all
metadata: {}
dataset_id: mmlu_cot
provider_id: huggingface
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/gpqa_0shot_cot?split=test&name=gpqa_main
metadata: {}
dataset_id: gpqa_cot
provider_id: huggingface
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/math_500?split=test
metadata: {}
dataset_id: math_500
provider_id: huggingface
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/bfcl_v3?split=train
metadata: {}
dataset_id: bfcl
provider_id: huggingface
scoring_fns: []
benchmarks:
- dataset_id: simpleqa