From f429e75b3e36dab2622ced3792a9ed4d541773e4 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Thu, 7 Nov 2024 21:31:05 -0800
Subject: [PATCH] fix tests

---
 llama_stack/providers/tests/eval/test_eval.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/llama_stack/providers/tests/eval/test_eval.py b/llama_stack/providers/tests/eval/test_eval.py
index 88f577cd8..5525e080e 100644
--- a/llama_stack/providers/tests/eval/test_eval.py
+++ b/llama_stack/providers/tests/eval/test_eval.py
@@ -15,13 +15,14 @@ from llama_stack.apis.eval.eval import (
     EvalTaskDefWithProvider,
     ModelCandidate,
 )
+from llama_stack.distribution.datatypes import Api
 from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
 
 # How to run this test:
 #
 # pytest llama_stack/providers/tests/eval/test_eval.py
-#   -m "meta_reference"
+#   -m "meta_reference_eval_together_inference_huggingface_datasetio"
 #   -v -s --tb=short --disable-warnings
 
 
 
@@ -46,7 +47,7 @@ class Testeval:
             datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval"
         )
         response = await datasets_impl.list_datasets()
-        assert len(response) == 1
+
         rows = await datasetio_impl.get_rows_paginated(
             dataset_id="test_dataset_for_eval",
             rows_in_page=3,
@@ -126,7 +127,11 @@ class Testeval:
 
     @pytest.mark.asyncio
     async def test_eval_run_benchmark_eval(self, eval_stack):
-        eval_impl, eval_tasks_impl, _, _, datasetio_impl, datasets_impl = eval_stack
+        eval_impl, eval_tasks_impl, datasets_impl = (
+            eval_stack[Api.eval],
+            eval_stack[Api.eval_tasks],
+            eval_stack[Api.datasets],
+        )
         response = await datasets_impl.list_datasets()
         assert len(response) > 0
         if response[0].provider_id != "huggingface":
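
Note on the last hunk: it moves from positional tuple unpacking of the eval_stack fixture to dict-style lookups keyed by Api, so each test only names the implementations it actually uses. Below is a minimal sketch (not part of the patch) of how a test might rely on that pattern; it assumes eval_stack resolves to a mapping from Api members to provider implementations, and the test name is illustrative.

    # Illustrative sketch only; assumes the eval_stack fixture is a dict keyed
    # by Api enum members, mirroring the access pattern in the patch above.
    import pytest

    from llama_stack.distribution.datatypes import Api


    @pytest.mark.asyncio
    async def test_datasets_listing_sketch(eval_stack):
        # Pull just the datasets implementation out of the fixture mapping.
        datasets_impl = eval_stack[Api.datasets]
        response = await datasets_impl.list_datasets()
        assert len(response) > 0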