From cb77712510708e9c9cbf397b7316f4e95c3c4fbf Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Tue, 12 Nov 2024 11:42:49 -0500 Subject: [PATCH] fix eval tasks --- llama_stack/distribution/datatypes.py | 5 +++++ .../providers/adapters/datasetio/huggingface/huggingface.py | 1 + 2 files changed, 6 insertions(+) diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index 51b56dd5f..d0888b981 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -17,6 +17,8 @@ from llama_stack.apis.memory_banks import * # noqa: F403 from llama_stack.apis.datasets import * # noqa: F403 from llama_stack.apis.scoring_functions import * # noqa: F403 from llama_stack.apis.datasetio import DatasetIO +from llama_stack.apis.eval import Eval +from llama_stack.apis.eval_tasks import EvalTask from llama_stack.apis.inference import Inference from llama_stack.apis.memory import Memory from llama_stack.apis.safety import Safety @@ -36,6 +38,7 @@ RoutableObject = Union[ MemoryBank, Dataset, ScoringFn, + EvalTask, ] @@ -46,6 +49,7 @@ RoutableObjectWithProvider = Annotated[ MemoryBank, Dataset, ScoringFn, + EvalTask, ], Field(discriminator="type"), ] @@ -56,6 +60,7 @@ RoutedProtocol = Union[ Memory, DatasetIO, Scoring, + Eval, ] diff --git a/llama_stack/providers/adapters/datasetio/huggingface/huggingface.py b/llama_stack/providers/adapters/datasetio/huggingface/huggingface.py index cd143a3ef..548866546 100644 --- a/llama_stack/providers/adapters/datasetio/huggingface/huggingface.py +++ b/llama_stack/providers/adapters/datasetio/huggingface/huggingface.py @@ -43,6 +43,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): self, dataset_def: Dataset, ) -> None: + print(f"Registering dataset: {dataset_def.identifier}") self.dataset_infos[dataset_def.identifier] = dataset_def async def get_rows_paginated(