mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-31 16:01:46 +00:00
fix scoring test
This commit is contained in:
parent
e27c6e3662
commit
68a4e6d00e
3 changed files with 41 additions and 14 deletions
|
@ -16,7 +16,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
pytest.param(
|
||||
{
|
||||
"scoring": "meta_reference",
|
||||
"datasetio": "meta_reference",
|
||||
"datasetio": "localfs",
|
||||
"inference": "fireworks",
|
||||
},
|
||||
id="meta_reference_scoring_fireworks_inference",
|
||||
|
@ -25,7 +25,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
pytest.param(
|
||||
{
|
||||
"scoring": "meta_reference",
|
||||
"datasetio": "meta_reference",
|
||||
"datasetio": "localfs",
|
||||
"inference": "together",
|
||||
},
|
||||
id="meta_reference_scoring_together_inference",
|
||||
|
|
|
@ -52,9 +52,4 @@ async def scoring_stack(request):
|
|||
provider_data,
|
||||
)
|
||||
|
||||
return (
|
||||
impls[Api.scoring],
|
||||
impls[Api.scoring_functions],
|
||||
impls[Api.datasetio],
|
||||
impls[Api.datasets],
|
||||
)
|
||||
return impls
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
import pytest
|
||||
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
|
||||
|
||||
# How to run this test:
|
||||
|
@ -23,20 +23,36 @@ class TestScoring:
|
|||
async def test_scoring_functions_list(self, scoring_stack):
|
||||
# NOTE: this needs you to ensure that you are starting from a clean state
|
||||
# but so far we don't have an unregister API unfortunately, so be careful
|
||||
_, scoring_functions_impl, _, _ = scoring_stack
|
||||
scoring_functions_impl = scoring_stack[Api.scoring_functions]
|
||||
response = await scoring_functions_impl.list_scoring_functions()
|
||||
assert isinstance(response, list)
|
||||
assert len(response) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scoring_score(self, scoring_stack):
|
||||
scoring_impl, scoring_functions_impl, datasetio_impl, datasets_impl = (
|
||||
scoring_stack
|
||||
(
|
||||
scoring_impl,
|
||||
scoring_functions_impl,
|
||||
datasetio_impl,
|
||||
datasets_impl,
|
||||
models_impl,
|
||||
) = (
|
||||
scoring_stack[Api.scoring],
|
||||
scoring_stack[Api.scoring_functions],
|
||||
scoring_stack[Api.datasetio],
|
||||
scoring_stack[Api.datasets],
|
||||
scoring_stack[Api.models],
|
||||
)
|
||||
await register_dataset(datasets_impl)
|
||||
response = await datasets_impl.list_datasets()
|
||||
assert len(response) == 1
|
||||
|
||||
for model_id in ["Llama3.2-3B-Instruct", "Llama3.1-8B-Instruct"]:
|
||||
await models_impl.register_model(
|
||||
model_id=model_id,
|
||||
provider_id="",
|
||||
)
|
||||
|
||||
# scoring individual rows
|
||||
rows = await datasetio_impl.get_rows_paginated(
|
||||
dataset_id="test_dataset",
|
||||
|
@ -69,13 +85,29 @@ class TestScoring:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scoring_score_with_params(self, scoring_stack):
|
||||
scoring_impl, scoring_functions_impl, datasetio_impl, datasets_impl = (
|
||||
scoring_stack
|
||||
(
|
||||
scoring_impl,
|
||||
scoring_functions_impl,
|
||||
datasetio_impl,
|
||||
datasets_impl,
|
||||
models_impl,
|
||||
) = (
|
||||
scoring_stack[Api.scoring],
|
||||
scoring_stack[Api.scoring_functions],
|
||||
scoring_stack[Api.datasetio],
|
||||
scoring_stack[Api.datasets],
|
||||
scoring_stack[Api.models],
|
||||
)
|
||||
await register_dataset(datasets_impl)
|
||||
response = await datasets_impl.list_datasets()
|
||||
assert len(response) == 1
|
||||
|
||||
for model_id in ["Llama3.1-405B-Instruct"]:
|
||||
await models_impl.register_model(
|
||||
model_id=model_id,
|
||||
provider_id="",
|
||||
)
|
||||
|
||||
# scoring individual rows
|
||||
rows = await datasetio_impl.get_rows_paginated(
|
||||
dataset_id="test_dataset",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue