add register model to unit test

Xi Yan 2024-11-11 10:35:59 -05:00
parent e690eb7ad3
commit 1031f1404b
8 changed files with 23 additions and 89 deletions

View file

@@ -23,8 +23,8 @@ def available_providers() -> List[ProviderSpec]:
             api=Api.datasetio,
             provider_type="huggingface",
             pip_packages=["datasets"],
-            module="llama_stack.providers.inline.huggingface.datasetio",
-            config_class="llama_stack.providers.inline.huggingface.datasetio.HuggingfaceDatasetIOConfig",
+            module="llama_stack.providers.adapters.datasetio.huggingface",
+            config_class="llama_stack.providers.adapters.datasetio.huggingface.HuggingfaceDatasetIOConfig",
             api_dependencies=[],
         ),
     ]
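
For orientation, the registry entry this hunk edits (likely the datasetio provider registry) plausibly reads as follows after the change. This is a sketch reconstructed from the hunk context: the InlineProviderSpec wrapper and the import lines are assumptions, while the keyword values are copied from the diff.

```python
# Sketch only: surrounding code inferred from the hunk context.
# The spec class name and import path are assumptions; the kwargs mirror the diff.
from typing import List

from llama_stack.distribution.datatypes import *  # noqa: F403


def available_providers() -> List[ProviderSpec]:
    return [
        InlineProviderSpec(  # wrapper class assumed
            api=Api.datasetio,
            provider_type="huggingface",
            pip_packages=["datasets"],
            # module/config_class now point at the adapters package (new side of the hunk)
            module="llama_stack.providers.adapters.datasetio.huggingface",
            config_class="llama_stack.providers.adapters.datasetio.huggingface.HuggingfaceDatasetIOConfig",
            api_dependencies=[],
        ),
    ]
```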

View file

@@ -37,12 +37,18 @@ class Testeval:
     @pytest.mark.asyncio
     async def test_eval_evaluate_rows(self, eval_stack):
-        eval_impl, eval_tasks_impl, datasetio_impl, datasets_impl = (
+        eval_impl, eval_tasks_impl, datasetio_impl, datasets_impl, models_impl = (
             eval_stack[Api.eval],
             eval_stack[Api.eval_tasks],
             eval_stack[Api.datasetio],
             eval_stack[Api.datasets],
+            eval_stack[Api.models],
         )
+        for model_id in ["Llama3.2-3B-Instruct", "Llama3.1-8B-Instruct"]:
+            await models_impl.register_model(
+                model_id=model_id,
+                provider_id="",
+            )
 
         await register_dataset(
             datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval"
         )
@@ -66,7 +72,6 @@ class Testeval:
             provider_id="meta-reference",
         )
         await eval_tasks_impl.register_eval_task(task_def)
-
         response = await eval_impl.evaluate_rows(
             task_id=task_id,
             input_rows=rows.rows,
@@ -84,11 +89,17 @@ class Testeval:
     @pytest.mark.asyncio
     async def test_eval_run_eval(self, eval_stack):
-        eval_impl, eval_tasks_impl, datasets_impl = (
+        eval_impl, eval_tasks_impl, datasets_impl, models_impl = (
             eval_stack[Api.eval],
             eval_stack[Api.eval_tasks],
             eval_stack[Api.datasets],
+            eval_stack[Api.models],
         )
+        for model_id in ["Llama3.2-3B-Instruct", "Llama3.1-8B-Instruct"]:
+            await models_impl.register_model(
+                model_id=model_id,
+                provider_id="",
+            )
 
         await register_dataset(
             datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval"
         )
@@ -127,11 +138,17 @@ class Testeval:
     @pytest.mark.asyncio
     async def test_eval_run_benchmark_eval(self, eval_stack):
-        eval_impl, eval_tasks_impl, datasets_impl = (
+        eval_impl, eval_tasks_impl, datasets_impl, models_impl = (
             eval_stack[Api.eval],
             eval_stack[Api.eval_tasks],
             eval_stack[Api.datasets],
+            eval_stack[Api.models],
         )
+        for model_id in ["Llama3.2-3B-Instruct", "Llama3.1-8B-Instruct"]:
+            await models_impl.register_model(
+                model_id=model_id,
+                provider_id="",
+            )
 
         response = await datasets_impl.list_datasets()
         assert len(response) > 0
         if response[0].provider_id != "huggingface":
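
The same two-model registration loop is now repeated in all three tests above. Below is a minimal sketch of how it could be factored into a shared helper; `register_eval_models` is a hypothetical name, not part of this commit, and it only reuses the `models_impl.register_model` call exactly as it appears in the diff.

```python
# Hypothetical helper, not part of this commit: factors out the model
# registration loop that each eval test above now performs inline.
async def register_eval_models(
    models_impl,
    model_ids=("Llama3.2-3B-Instruct", "Llama3.1-8B-Instruct"),
):
    for model_id in model_ids:
        # Same call shape as in the diff: register against the default provider.
        await models_impl.register_model(
            model_id=model_id,
            provider_id="",
        )
```

Each test body could then call `await register_eval_models(eval_stack[Api.models])` before registering the dataset, instead of repeating the loop.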

View file

@@ -1,83 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import pytest
-import pytest_asyncio
-
-from llama_stack.apis.common.type_system import *  # noqa: F403
-from llama_stack.apis.datasetio import *  # noqa: F403
-from llama_stack.apis.eval.eval import ModelCandidate
-from llama_stack.distribution.datatypes import *  # noqa: F403
-
-from llama_models.llama3.api import SamplingParams
-
-from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
-from llama_stack.providers.tests.resolver import resolve_impls_for_test
-
-# How to run this test:
-#
-# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky
-#    since it depends on the provider you are testing. On top of that you need
-#    `pytest` and `pytest-asyncio` installed.
-#
-# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
-#
-# 3. Run:
-#
-# ```bash
-# PROVIDER_ID=<your_provider> \
-#  PROVIDER_CONFIG=provider_config.yaml \
-#  pytest -s llama_stack/providers/tests/eval/test_eval.py \
-#  --tb=short --disable-warnings
-# ```
-
-
-@pytest_asyncio.fixture(scope="session")
-async def eval_settings():
-    impls = await resolve_impls_for_test(
-        Api.eval, deps=[Api.datasetio, Api.scoring, Api.inference]
-    )
-    return {
-        "eval_impl": impls[Api.eval],
-        "scoring_impl": impls[Api.scoring],
-        "datasets_impl": impls[Api.datasets],
-    }
-
-
-@pytest.mark.asyncio
-async def test_eval(eval_settings):
-    datasets_impl = eval_settings["datasets_impl"]
-    await register_dataset(
-        datasets_impl,
-        for_generation=True,
-        dataset_id="test_dataset_for_eval",
-    )
-
-    response = await datasets_impl.list_datasets()
-    assert len(response) == 1
-
-    eval_impl = eval_settings["eval_impl"]
-    response = await eval_impl.evaluate_batch(
-        dataset_id=response[0].identifier,
-        candidate=ModelCandidate(
-            model="Llama3.2-1B-Instruct",
-            sampling_params=SamplingParams(),
-        ),
-        scoring_functions=[
-            "meta-reference::subset_of",
-            "meta-reference::llm_as_judge_8b_correctness",
-        ],
-    )
-
-    assert response.job_id == "0"
-    job_status = await eval_impl.job_status(response.job_id)
-    assert job_status and job_status.value == "completed"
-
-    eval_response = await eval_impl.job_result(response.job_id)
-    assert eval_response is not None
-    assert len(eval_response.generations) == 5
-    assert "meta-reference::subset_of" in eval_response.scores
-    assert "meta-reference::llm_as_judge_8b_correctness" in eval_response.scores