fix tests

This commit is contained in:
Xi Yan 2024-11-07 21:31:05 -08:00
parent 0443b36cc1
commit f429e75b3e

View file

@ -15,13 +15,14 @@ from llama_stack.apis.eval.eval import (
EvalTaskDefWithProvider, EvalTaskDefWithProvider,
ModelCandidate, ModelCandidate,
) )
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
# How to run this test: # How to run this test:
# #
# pytest llama_stack/providers/tests/eval/test_eval.py # pytest llama_stack/providers/tests/eval/test_eval.py
# -m "meta_reference" # -m "meta_reference_eval_together_inference_huggingface_datasetio"
# -v -s --tb=short --disable-warnings # -v -s --tb=short --disable-warnings
@ -46,7 +47,7 @@ class Testeval:
datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval" datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval"
) )
response = await datasets_impl.list_datasets() response = await datasets_impl.list_datasets()
assert len(response) == 1
rows = await datasetio_impl.get_rows_paginated( rows = await datasetio_impl.get_rows_paginated(
dataset_id="test_dataset_for_eval", dataset_id="test_dataset_for_eval",
rows_in_page=3, rows_in_page=3,
@ -126,7 +127,11 @@ class Testeval:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_eval_run_benchmark_eval(self, eval_stack): async def test_eval_run_benchmark_eval(self, eval_stack):
eval_impl, eval_tasks_impl, _, _, datasetio_impl, datasets_impl = eval_stack eval_impl, eval_tasks_impl, datasets_impl = (
eval_stack[Api.eval],
eval_stack[Api.eval_tasks],
eval_stack[Api.datasets],
)
response = await datasets_impl.list_datasets() response = await datasets_impl.list_datasets()
assert len(response) > 0 assert len(response) > 0
if response[0].provider_id != "huggingface": if response[0].provider_id != "huggingface":