From 0dc2b9c6e4b496176cec347fcedaacedef337442 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 7 Nov 2024 21:22:18 -0800 Subject: [PATCH] fixture return impl --- llama_stack/providers/tests/eval/fixtures.py | 9 +-------- llama_stack/providers/tests/eval/test_eval.py | 16 +++++++++++++--- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/llama_stack/providers/tests/eval/fixtures.py b/llama_stack/providers/tests/eval/fixtures.py index 22181f3b2..810239440 100644 --- a/llama_stack/providers/tests/eval/fixtures.py +++ b/llama_stack/providers/tests/eval/fixtures.py @@ -52,11 +52,4 @@ async def eval_stack(request): provider_data, ) - return ( - impls[Api.eval], - impls[Api.eval_tasks], - impls[Api.scoring], - impls[Api.scoring_functions], - impls[Api.datasetio], - impls[Api.datasets], - ) + return impls diff --git a/llama_stack/providers/tests/eval/test_eval.py b/llama_stack/providers/tests/eval/test_eval.py index 242be50ca..a55a754c5 100644 --- a/llama_stack/providers/tests/eval/test_eval.py +++ b/llama_stack/providers/tests/eval/test_eval.py @@ -14,6 +14,7 @@ from llama_stack.apis.eval.eval import ( EvalTaskDefWithProvider, ModelCandidate, ) +from llama_stack.distribution.datatypes import Api from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset @@ -29,14 +30,19 @@ class Testeval: async def test_eval_tasks_list(self, eval_stack): # NOTE: this needs you to ensure that you are starting from a clean state # but so far we don't have an unregister API unfortunately, so be careful - _, eval_tasks_impl, _, _, _, _ = eval_stack + eval_tasks_impl = eval_stack[Api.eval_tasks] response = await eval_tasks_impl.list_eval_tasks() assert isinstance(response, list) assert len(response) == 0 @pytest.mark.asyncio async def test_eval_evaluate_rows(self, eval_stack): - eval_impl, eval_tasks_impl, _, _, datasetio_impl, datasets_impl = eval_stack + eval_impl, eval_tasks_impl, datasetio_impl, datasets_impl = ( + eval_stack[Api.eval], + eval_stack[Api.eval_tasks], + eval_stack[Api.datasetio], + eval_stack[Api.datasets], + ) await register_dataset( datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval" ) @@ -78,7 +84,11 @@ class Testeval: @pytest.mark.asyncio async def test_eval_run_eval(self, eval_stack): - eval_impl, eval_tasks_impl, _, _, datasetio_impl, datasets_impl = eval_stack + eval_impl, eval_tasks_impl, datasets_impl = ( + eval_stack[Api.eval], + eval_stack[Api.eval_tasks], + eval_stack[Api.datasets], + ) await register_dataset( datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval" )