mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-16 14:57:20 +00:00
merge
This commit is contained in:
commit
33b6d9b7b7
8 changed files with 67 additions and 304 deletions
|
@ -11,8 +11,7 @@ from llama_models.llama3.api import SamplingParams
|
|||
|
||||
from llama_stack.apis.eval.eval import (
|
||||
AppEvalTaskConfig,
|
||||
BenchmarkEvalTaskConfig,
|
||||
EvalTaskDef,
|
||||
EvalTaskDefWithProvider,
|
||||
ModelCandidate,
|
||||
)
|
||||
from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
|
||||
|
@ -33,7 +32,7 @@ class Testeval:
|
|||
_, eval_tasks_impl, _, _, _, _ = eval_stack
|
||||
response = await eval_tasks_impl.list_eval_tasks()
|
||||
assert isinstance(response, list)
|
||||
assert len(response) >= 1
|
||||
assert len(response) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_eval_evaluate_rows(self, eval_stack):
|
||||
|
@ -59,8 +58,17 @@ class Testeval:
|
|||
"meta-reference::llm_as_judge_8b_correctness",
|
||||
"meta-reference::equality",
|
||||
]
|
||||
task_id = "meta-reference::app_eval"
|
||||
task_def = EvalTaskDefWithProvider(
|
||||
identifier=task_id,
|
||||
dataset_id="test_dataset_for_eval",
|
||||
scoring_functions=scoring_functions,
|
||||
provider_id="meta-reference",
|
||||
)
|
||||
await eval_tasks_impl.register_eval_task(task_def)
|
||||
|
||||
response = await eval_impl.evaluate_rows(
|
||||
task_id=task_id,
|
||||
input_rows=rows.rows,
|
||||
scoring_functions=scoring_functions,
|
||||
task_config=AppEvalTaskConfig(
|
||||
|
@ -91,13 +99,16 @@ class Testeval:
|
|||
"meta-reference::subset_of",
|
||||
]
|
||||
|
||||
task_id = "meta-reference::app_eval-2"
|
||||
task_def = EvalTaskDefWithProvider(
|
||||
identifier=task_id,
|
||||
dataset_id="test_dataset_for_eval",
|
||||
scoring_functions=scoring_functions,
|
||||
provider_id="meta-reference",
|
||||
)
|
||||
await eval_tasks_impl.register_eval_task(task_def)
|
||||
response = await eval_impl.run_eval(
|
||||
task=EvalTaskDef(
|
||||
# NOTE: this is needed to make the router work for all app evals
|
||||
identifier="meta-reference::app_eval",
|
||||
dataset_id="test_dataset_for_eval",
|
||||
scoring_functions=scoring_functions,
|
||||
),
|
||||
task_id=task_id,
|
||||
task_config=AppEvalTaskConfig(
|
||||
eval_candidate=ModelCandidate(
|
||||
model="Llama3.2-3B-Instruct",
|
||||
|
@ -106,13 +117,9 @@ class Testeval:
|
|||
),
|
||||
)
|
||||
assert response.job_id == "0"
|
||||
job_status = await eval_impl.job_status(
|
||||
response.job_id, "meta-reference::app_eval"
|
||||
)
|
||||
job_status = await eval_impl.job_status(task_id, response.job_id)
|
||||
assert job_status and job_status.value == "completed"
|
||||
eval_response = await eval_impl.job_result(
|
||||
response.job_id, "meta-reference::app_eval"
|
||||
)
|
||||
eval_response = await eval_impl.job_result(task_id, response.job_id)
|
||||
|
||||
assert eval_response is not None
|
||||
assert len(eval_response.generations) == 5
|
||||
|
|
|
@ -1,17 +0,0 @@
|
|||
providers:
|
||||
datasetio:
|
||||
- provider_id: test-meta
|
||||
provider_type: meta-reference
|
||||
config: {}
|
||||
scoring:
|
||||
- provider_id: test-meta
|
||||
provider_type: meta-reference
|
||||
config: {}
|
||||
- provider_id: test-braintrust
|
||||
provider_type: braintrust
|
||||
config: {}
|
||||
inference:
|
||||
- provider_id: tgi0
|
||||
provider_type: remote::tgi
|
||||
config:
|
||||
url: http://127.0.0.1:5009
|
|
@ -1,152 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from llama_stack.apis.datasetio import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
|
||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test
|
||||
|
||||
# How to run this test:
|
||||
#
|
||||
# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky
|
||||
# since it depends on the provider you are testing. On top of that you need
|
||||
# `pytest` and `pytest-asyncio` installed.
|
||||
#
|
||||
# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
|
||||
#
|
||||
# 3. Run:
|
||||
#
|
||||
# ```bash
|
||||
# PROVIDER_ID=<your_provider> \
|
||||
# PROVIDER_CONFIG=provider_config.yaml \
|
||||
# pytest -s llama_stack/providers/tests/scoring/test_scoring.py \
|
||||
# --tb=short --disable-warnings
|
||||
# ```
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def scoring_settings():
|
||||
impls = await resolve_impls_for_test(
|
||||
Api.scoring, deps=[Api.datasetio, Api.inference]
|
||||
)
|
||||
return {
|
||||
"scoring_impl": impls[Api.scoring],
|
||||
"scoring_functions_impl": impls[Api.scoring_functions],
|
||||
"datasets_impl": impls[Api.datasets],
|
||||
}
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def provider_scoring_functions():
|
||||
return {
|
||||
"meta-reference": {
|
||||
"meta-reference::equality",
|
||||
"meta-reference::subset_of",
|
||||
"meta-reference::llm_as_judge_8b_correctness",
|
||||
},
|
||||
"braintrust": {
|
||||
"braintrust::factuality",
|
||||
"braintrust::answer-correctness",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scoring_functions_list(scoring_settings, provider_scoring_functions):
|
||||
scoring_impl = scoring_settings["scoring_impl"]
|
||||
scoring_functions_impl = scoring_settings["scoring_functions_impl"]
|
||||
scoring_functions = await scoring_functions_impl.list_scoring_functions()
|
||||
assert isinstance(scoring_functions, list)
|
||||
assert len(scoring_functions) > 0
|
||||
function_ids = [f.identifier for f in scoring_functions]
|
||||
# get current provider_type we're testing
|
||||
provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
|
||||
provider_type = provider.__provider_spec__.provider_type
|
||||
|
||||
for x in provider_scoring_functions[provider_type]:
|
||||
assert x in function_ids
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scoring_functions_register(scoring_settings):
|
||||
scoring_impl = scoring_settings["scoring_impl"]
|
||||
scoring_functions_impl = scoring_settings["scoring_functions_impl"]
|
||||
datasets_impl = scoring_settings["datasets_impl"]
|
||||
|
||||
# get current provider_type we're testing
|
||||
scoring_functions = await scoring_functions_impl.list_scoring_functions()
|
||||
function_ids = [f.identifier for f in scoring_functions]
|
||||
provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
|
||||
provider_type = provider.__provider_spec__.provider_type
|
||||
if provider_type not in ("meta-reference"):
|
||||
pytest.skip(
|
||||
"Other scoring providers don't support registering scoring functions."
|
||||
)
|
||||
|
||||
test_prompt = """Output a number between 0 to 10. Your answer must match the format \n Number: <answer>"""
|
||||
# register the scoring function
|
||||
await scoring_functions_impl.register_scoring_function(
|
||||
ScoringFnDefWithProvider(
|
||||
identifier="meta-reference::llm_as_judge_8b_random",
|
||||
description="Llm As Judge Scoring Function",
|
||||
parameters=[],
|
||||
return_type=NumberType(),
|
||||
context=LLMAsJudgeContext(
|
||||
prompt_template=test_prompt,
|
||||
judge_model="Llama3.1-8B-Instruct",
|
||||
judge_score_regex=[r"Number: (\d+)"],
|
||||
),
|
||||
provider_id="test-meta",
|
||||
)
|
||||
)
|
||||
|
||||
scoring_functions = await scoring_functions_impl.list_scoring_functions()
|
||||
assert isinstance(scoring_functions, list)
|
||||
assert len(scoring_functions) > 0
|
||||
function_ids = [f.identifier for f in scoring_functions]
|
||||
assert "meta-reference::llm_as_judge_8b_random" in function_ids
|
||||
|
||||
# test score using newly registered scoring function
|
||||
await register_dataset(datasets_impl)
|
||||
response = await datasets_impl.list_datasets()
|
||||
assert len(response) == 1
|
||||
response = await scoring_impl.score_batch(
|
||||
dataset_id=response[0].identifier,
|
||||
scoring_functions=[
|
||||
"meta-reference::llm_as_judge_8b_random",
|
||||
],
|
||||
)
|
||||
assert "meta-reference::llm_as_judge_8b_random" in response.results
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scoring_score(scoring_settings, provider_scoring_functions):
|
||||
scoring_impl = scoring_settings["scoring_impl"]
|
||||
datasets_impl = scoring_settings["datasets_impl"]
|
||||
scoring_functions_impl = scoring_settings["scoring_functions_impl"]
|
||||
await register_dataset(datasets_impl)
|
||||
|
||||
response = await datasets_impl.list_datasets()
|
||||
assert len(response) == 1
|
||||
|
||||
# get current provider_type we're testing
|
||||
scoring_functions = await scoring_functions_impl.list_scoring_functions()
|
||||
function_ids = [f.identifier for f in scoring_functions]
|
||||
provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
|
||||
provider_type = provider.__provider_spec__.provider_type
|
||||
|
||||
response = await scoring_impl.score_batch(
|
||||
dataset_id=response[0].identifier,
|
||||
scoring_functions=list(provider_scoring_functions[provider_type]),
|
||||
)
|
||||
|
||||
assert len(response.results) == len(provider_scoring_functions[provider_type])
|
||||
for x in provider_scoring_functions[provider_type]:
|
||||
assert x in response.results
|
Loading…
Add table
Add a link
Reference in a new issue