fix after rebase

Dinesh Yeduguru 2024-11-11 17:20:19 -08:00
parent 48dcea18f2
commit cab15c9cfa
2 changed files with 8 additions and 13 deletions


@@ -57,8 +57,8 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         self.eval_tasks[task_def.identifier] = task_def
 
     async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
-        dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
-        if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
+        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
+        if not dataset_def.schema or len(dataset_def.schema) == 0:
             raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
 
         expected_schemas = [
@@ -74,7 +74,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
             },
         ]
-        if dataset_def.dataset_schema not in expected_schemas:
+        if dataset_def.schema not in expected_schemas:
             raise ValueError(
                 f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
             )
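For context, a minimal sketch of how a caller exercises the renamed DatasetIO accessors after this change; `datasets_api` stands in for the Datasets API dependency used by MetaReferenceEvalImpl in the hunks above, and any name not shown in the diff is illustrative rather than the library's confirmed API:

# Minimal sketch, not part of the commit.
# Assumes `datasets_api` behaves like the Datasets API client used above.
async def check_dataset_has_schema(datasets_api, dataset_id: str) -> None:
    # keyword renamed: dataset_identifier -> dataset_id
    dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
    # field renamed: dataset_schema -> schema
    if not dataset_def.schema or len(dataset_def.schema) == 0:
        raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")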


@@ -11,8 +11,6 @@ from llama_models.llama3.api import SamplingParams, URL
 
 from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType
 
-from llama_stack.apis.datasetio.datasetio import DatasetDefWithProvider
-
 from llama_stack.apis.eval.eval import (
     AppEvalTaskConfig,
     BenchmarkEvalTaskConfig,
@@ -164,25 +162,22 @@ class Testeval:
             pytest.skip(
                 "Only huggingface provider supports pre-registered remote datasets"
             )
         # register dataset
-        mmlu = DatasetDefWithProvider(
-            identifier="mmlu",
-            url=URL(uri="https://huggingface.co/datasets/llamastack/evals"),
-            dataset_schema={
+        await datasets_impl.register_dataset(
+            dataset_id="mmlu",
+            schema={
                 "input_query": StringType(),
                 "expected_answer": StringType(),
                 "chat_completion_input": ChatCompletionInputType(),
             },
+            url=URL(uri="https://huggingface.co/datasets/llamastack/evals"),
             metadata={
                 "path": "llamastack/evals",
                 "name": "evals__mmlu__details",
                 "split": "train",
             },
-            provider_id="",
         )
-        await datasets_impl.register_dataset(mmlu)
 
         # register eval task
         await eval_tasks_impl.register_eval_task(
             eval_task_id="meta-reference-mmlu",
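The schema registered here is the same shape that validate_eval_input_dataset_schema checks on the eval side. A minimal sketch of that expected input schema, using the type_system classes imported above (eval.py may accept additional variants; only this one is visible in the diff):

from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType

# One input schema accepted for eval datasets, as registered by the test above.
# Sketch only; the full expected_schemas list in eval.py is not shown in this diff.
expected_schema = {
    "input_query": StringType(),
    "expected_answer": StringType(),
    "chat_completion_input": ChatCompletionInputType(),
}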