fix after rebase

This commit is contained in:
Dinesh Yeduguru 2024-11-11 17:20:19 -08:00
parent 48dcea18f2
commit cab15c9cfa
2 changed files with 8 additions and 13 deletions

View file

@@ -57,8 +57,8 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
self.eval_tasks[task_def.identifier] = task_def self.eval_tasks[task_def.identifier] = task_def
async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None: async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id) dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0: if not dataset_def.schema or len(dataset_def.schema) == 0:
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.") raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
expected_schemas = [ expected_schemas = [
@@ -74,7 +74,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
}, },
] ]
if dataset_def.dataset_schema not in expected_schemas: if dataset_def.schema not in expected_schemas:
raise ValueError( raise ValueError(
f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}" f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
) )

View file

@@ -11,8 +11,6 @@ from llama_models.llama3.api import SamplingParams, URL
from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType
from llama_stack.apis.datasetio.datasetio import DatasetDefWithProvider
from llama_stack.apis.eval.eval import ( from llama_stack.apis.eval.eval import (
AppEvalTaskConfig, AppEvalTaskConfig,
BenchmarkEvalTaskConfig, BenchmarkEvalTaskConfig,
@@ -164,25 +162,22 @@ class Testeval:
pytest.skip( pytest.skip(
"Only huggingface provider supports pre-registered remote datasets" "Only huggingface provider supports pre-registered remote datasets"
) )
# register dataset
mmlu = DatasetDefWithProvider( await datasets_impl.register_dataset(
identifier="mmlu", dataset_id="mmlu",
url=URL(uri="https://huggingface.co/datasets/llamastack/evals"), schema={
dataset_schema={
"input_query": StringType(), "input_query": StringType(),
"expected_answer": StringType(), "expected_answer": StringType(),
"chat_completion_input": ChatCompletionInputType(), "chat_completion_input": ChatCompletionInputType(),
}, },
url=URL(uri="https://huggingface.co/datasets/llamastack/evals"),
metadata={ metadata={
"path": "llamastack/evals", "path": "llamastack/evals",
"name": "evals__mmlu__details", "name": "evals__mmlu__details",
"split": "train", "split": "train",
}, },
provider_id="",
) )
await datasets_impl.register_dataset(mmlu)
# register eval task # register eval task
await eval_tasks_impl.register_eval_task( await eval_tasks_impl.register_eval_task(
eval_task_id="meta-reference-mmlu", eval_task_id="meta-reference-mmlu",