diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py
index f63d77421..ba2fc7c95 100644
--- a/llama_stack/providers/inline/eval/meta_reference/eval.py
+++ b/llama_stack/providers/inline/eval/meta_reference/eval.py
@@ -57,8 +57,8 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         self.eval_tasks[task_def.identifier] = task_def
 
     async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
-        dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
-        if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
+        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
+        if not dataset_def.schema or len(dataset_def.schema) == 0:
             raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
 
         expected_schemas = [
@@ -74,7 +74,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
             },
         ]
 
-        if dataset_def.dataset_schema not in expected_schemas:
+        if dataset_def.schema not in expected_schemas:
             raise ValueError(
                 f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
             )
diff --git a/llama_stack/providers/tests/eval/test_eval.py b/llama_stack/providers/tests/eval/test_eval.py
index 0728ce5d9..92c4d0331 100644
--- a/llama_stack/providers/tests/eval/test_eval.py
+++ b/llama_stack/providers/tests/eval/test_eval.py
@@ -11,8 +11,6 @@ from llama_models.llama3.api import SamplingParams, URL
 
 from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType
 
-from llama_stack.apis.datasetio.datasetio import DatasetDefWithProvider
-
 from llama_stack.apis.eval.eval import (
     AppEvalTaskConfig,
     BenchmarkEvalTaskConfig,
@@ -164,25 +162,22 @@ class Testeval:
             pytest.skip(
                 "Only huggingface provider supports pre-registered remote datasets"
             )
-        # register dataset
-        mmlu = DatasetDefWithProvider(
-            identifier="mmlu",
-            url=URL(uri="https://huggingface.co/datasets/llamastack/evals"),
-            dataset_schema={
+
+        await datasets_impl.register_dataset(
+            dataset_id="mmlu",
+            schema={
                 "input_query": StringType(),
                 "expected_answer": StringType(),
                 "chat_completion_input": ChatCompletionInputType(),
             },
+            url=URL(uri="https://huggingface.co/datasets/llamastack/evals"),
             metadata={
                 "path": "llamastack/evals",
                 "name": "evals__mmlu__details",
                 "split": "train",
             },
-            provider_id="",
         )
-        await datasets_impl.register_dataset(mmlu)
-
         # register eval task
         await eval_tasks_impl.register_eval_task(
             eval_task_id="meta-reference-mmlu",
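
Note: a minimal sketch of the call shape this diff migrates to, assuming a `datasets_impl` handle to the Datasets API. The keyword names (`dataset_id`, `schema`, `url`, `metadata`) come straight from the test above; the dataset id and wrapper function here are hypothetical, for illustration only.

    from llama_models.llama3.api import URL
    from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType

    async def register_mmlu_style_dataset(datasets_impl) -> None:
        # Registration now takes keyword arguments directly instead of a
        # DatasetDefWithProvider object; provider_id is no longer passed.
        await datasets_impl.register_dataset(
            dataset_id="my-eval-dataset",  # hypothetical id
            schema={
                "input_query": StringType(),
                "expected_answer": StringType(),
                "chat_completion_input": ChatCompletionInputType(),
            },
            url=URL(uri="https://huggingface.co/datasets/llamastack/evals"),
            metadata={
                "path": "llamastack/evals",
                "name": "evals__mmlu__details",
                "split": "train",
            },
        )

    # On the read side, per the eval.py hunk, get_dataset() likewise keys on
    # dataset_id, and the schema is exposed as .schema rather than .dataset_schema:
    #     dataset_def = await datasets_api.get_dataset(dataset_id="my-eval-dataset")
    #     assert dataset_def.schema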