Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)

Commit 68346fac39: "datasetio test fix"
Parent: d3d2243dfb
6 changed files with 100 additions and 46 deletions

@@ -2,4 +2,4 @@ include requirements.txt
 include llama_stack/distribution/*.sh
 include llama_stack/cli/scripts/*.sh
 include llama_stack/templates/*/build.yaml
-include llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/*.json
+include llama_stack/providers/impls/*/scoring/scoring_fn/fn_defs/*.json

@@ -11,6 +11,8 @@ from llama_stack.apis.scoring_functions import * # noqa: F403
 from llama_stack.apis.common.type_system import * # noqa: F403
 from llama_stack.apis.datasetio import * # noqa: F403
 from llama_stack.apis.datasets import * # noqa: F403
+from autoevals.llm import Factuality
+from autoevals.ragas import AnswerCorrectness
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
 
 from .config import BraintrustScoringConfig

@@ -28,29 +30,29 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         self.datasets_api = datasets_api
         self.scoring_fn_id_impls = {}
 
-    async def initialize(self) -> None: ...
-
-        # for x in FIXED_FNS:
-        #     impl = x()
-        #     await impl.initialize()
-        #     for fn_defs in impl.get_supported_scoring_fn_defs():
-        #         self.scoring_fn_id_impls[fn_defs.identifier] = impl
-        # for x in LLM_JUDGE_FNS:
-        #     impl = x(inference_api=self.inference_api)
-        #     await impl.initialize()
-        #     for fn_defs in impl.get_supported_scoring_fn_defs():
-        #         self.scoring_fn_id_impls[fn_defs.identifier] = impl
-        #         self.llm_as_judge_fn = impl
+    async def initialize(self) -> None:
+        self.scoring_fn_id_impls = {
+            "braintrust::factuality": Factuality(),
+            "braintrust::answer-correctness": AnswerCorrectness(),
+        }
 
     async def shutdown(self) -> None: ...
 
     async def list_scoring_functions(self) -> List[ScoringFnDef]:
-        return []
-        # return [
-        #     fn_defs
-        #     for impl in self.scoring_fn_id_impls.values()
-        #     for fn_defs in impl.get_supported_scoring_fn_defs()
-        # ]
+        return [
+            ScoringFnDef(
+                identifier="braintrust::factuality",
+                description="Test whether an output is factual, compared to an original (`expected`) value.",
+                parameters=[],
+                return_type=NumberType(),
+            ),
+            ScoringFnDef(
+                identifier="braintrust::answer-correctness",
+                description="Test whether an output is factual, compared to an original (`expected`) value.",
+                parameters=[],
+                return_type=NumberType(),
+            ),
+        ]
 
     async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
         # self.llm_as_judge_fn.register_scoring_fn_def(function_def)

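For context on the two autoevals objects now stored in scoring_fn_id_impls, here is a minimal sketch (not part of this commit) of how an autoevals LLM scorer is typically invoked on a single row. The row keys are assumptions, the call signature should be checked against the installed autoevals version, and the scorers need LLM credentials (e.g. OPENAI_API_KEY) at call time.

# Sketch only: invoking the Factuality scorer the provider registers above.
# "generated_answer", "expected_answer" and "input_query" are assumed row keys,
# not names taken from this commit.
from autoevals.llm import Factuality


def score_one_row(row: dict) -> dict:
    evaluator = Factuality()
    result = evaluator(
        output=row["generated_answer"],
        expected=row["expected_answer"],
        input=row.get("input_query"),
    )
    # autoevals returns a Score object carrying a numeric score and metadata.
    return {"score": result.score, "metadata": result.metadata}
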
@@ -81,7 +83,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         print("score_batch")
-        # await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
+        await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
         # all_rows = await self.datasetio_api.get_rows_paginated(
         #     dataset_id=dataset_id,
         #     rows_in_page=-1,

@@ -103,16 +105,17 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     ) -> ScoreResponse:
         res = {}
         print("score")
-        # for scoring_fn_id in scoring_functions:
-        #     if scoring_fn_id not in self.scoring_fn_id_impls:
-        #         raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
-        #     scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
-        #     score_results = await scoring_fn.score(input_rows, scoring_fn_id)
-        #     agg_results = await scoring_fn.aggregate(score_results)
-        #     res[scoring_fn_id] = ScoringResult(
-        #         score_rows=score_results,
-        #         aggregated_results=agg_results,
-        #     )
+        for scoring_fn_id in scoring_functions:
+            if scoring_fn_id not in self.scoring_fn_id_impls:
+                raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
+
+            # scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
+            # score_results = await scoring_fn.score(input_rows, scoring_fn_id)
+            # agg_results = await scoring_fn.aggregate(score_results)
+            # res[scoring_fn_id] = ScoringResult(
+            #     score_rows=score_results,
+            #     aggregated_results=agg_results,
+            # )
 
         return ScoreResponse(
             results=res,

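The per-function scoring body itself is still commented out above. A hedged sketch of one way it could be completed against an autoevals evaluator, producing the score_rows / aggregated_results shape that ScoringResult expects; the row keys and the simple average aggregation are assumptions, not the provider's final logic.

# Sketch only: score every row with one evaluator, then aggregate.
from typing import Any, Callable, Dict, List


def score_rows_with(evaluator: Callable[..., Any], rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    score_rows = []
    for row in rows:
        result = evaluator(
            output=row["generated_answer"],   # assumed key
            expected=row["expected_answer"],  # assumed key
            input=row.get("input_query"),     # assumed key
        )
        score_rows.append({"score": result.score})
    scores = [s["score"] for s in score_rows if s["score"] is not None]
    return {
        "score_rows": score_rows,
        "aggregated_results": {"average": sum(scores) / len(scores) if scores else None},
    }

In score(), res would then map each scoring_fn_id to a ScoringResult built from such a dict before ScoreResponse(results=res) is returned.
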
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.

@@ -0,0 +1,10 @@
+{
+    "identifier": "braintrust::factuality",
+    "description": "Test whether an output is factual, compared to an original (`expected`) value. ",
+    "metadata": {},
+    "parameters": [],
+    "return_type": {
+        "type": "number"
+    },
+    "context": null
+}

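The new JSON file describes braintrust::factuality in the same shape as the ScoringFnDef returned by list_scoring_functions(). A minimal sketch (not the repo's actual loader) of reading such a fn_defs file; the path in the usage comment is hypothetical.

import json
from pathlib import Path


def load_fn_def(path: str) -> dict:
    # The JSON mirrors ScoringFnDef fields: identifier, description, metadata,
    # parameters, return_type (here {"type": "number"}), and an optional context.
    data = json.loads(Path(path).read_text())
    return {
        "identifier": data["identifier"],
        "description": data["description"],
        "parameters": data.get("parameters", []),
        "return_type": data.get("return_type"),
    }


# Hypothetical usage:
# fn_def = load_fn_def("path/to/scoring_fn/fn_defs/factuality.json")
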
@@ -82,7 +82,7 @@ async def register_dataset(
 
     dataset = DatasetDefWithProvider(
         identifier=dataset_id,
-        provider_id=os.environ["PROVIDER_ID"],
+        provider_id=os.environ["PROVIDER_ID"] or os.environ["DATASETIO_PROVIDER_ID"],
         url=URL(
             uri=test_url,
         ),

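A note on the provider_id change above: indexing os.environ raises KeyError when a variable is absent, so the `or` fallback only applies when PROVIDER_ID is set but empty. A sketch of a fallback that also tolerates a missing variable (an alternative, not what this commit does):

import os

# os.getenv returns None for missing variables instead of raising KeyError,
# so this falls back whenever PROVIDER_ID is unset or empty.
provider_id = os.getenv("PROVIDER_ID") or os.getenv("DATASETIO_PROVIDER_ID")
if provider_id is None:
    raise RuntimeError("Set PROVIDER_ID or DATASETIO_PROVIDER_ID for the datasetio tests.")
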
@@ -43,16 +43,35 @@ async def scoring_settings():
     }
 
 
+@pytest_asyncio.fixture(scope="session")
+async def provider_scoring_functions():
+    return {
+        "meta-reference": {
+            "meta-reference::equality",
+            "meta-reference::subset_of",
+            "meta-reference::llm_as_judge_8b_correctness",
+        },
+        "braintrust": {
+            "braintrust::factuality",
+            "braintrust::answer-correctness",
+        },
+    }
+
+
 @pytest.mark.asyncio
-async def test_scoring_functions_list(scoring_settings):
+async def test_scoring_functions_list(scoring_settings, provider_scoring_functions):
     scoring_impl = scoring_settings["scoring_impl"]
     scoring_functions_impl = scoring_settings["scoring_functions_impl"]
     scoring_functions = await scoring_functions_impl.list_scoring_functions()
     assert isinstance(scoring_functions, list)
     assert len(scoring_functions) > 0
     function_ids = [f.identifier for f in scoring_functions]
-    assert "meta-reference::equality" in function_ids
-    assert "meta-reference::subset_of" in function_ids
-    assert "meta-reference::llm_as_judge_8b_correctness" in function_ids
+    # get current provider_type we're testing
+    provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
+    provider_type = provider.__provider_spec__.provider_type
+
+    for x in provider_scoring_functions[provider_type]:
+        assert x in function_ids
+
 
 @pytest.mark.asyncio

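The provider-detection idiom introduced here (look up the provider of the first listed function, then read its provider_type) is repeated inline in the register and score tests below. A small sketch of it factored into a helper; scoring_impl is assumed to be the routed scoring API object from scoring_settings.

# Sketch only: the lookup the updated tests repeat inline.
def get_provider_type(scoring_impl, function_ids) -> str:
    provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
    return provider.__provider_spec__.provider_type
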
@@ -60,6 +79,17 @@ async def test_scoring_functions_register(scoring_settings):
     scoring_impl = scoring_settings["scoring_impl"]
     scoring_functions_impl = scoring_settings["scoring_functions_impl"]
     datasets_impl = scoring_settings["datasets_impl"]
+
+    # get current provider_type we're testing
+    scoring_functions = await scoring_functions_impl.list_scoring_functions()
+    function_ids = [f.identifier for f in scoring_functions]
+    provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
+    provider_type = provider.__provider_spec__.provider_type
+    if provider_type not in ("meta-reference"):
+        pytest.skip(
+            "Other scoring providers don't support registering scoring functions."
+        )
+
     test_prompt = """Output a number between 0 to 10. Your answer must match the format \n Number: <answer>"""
     # register the scoring function
     await scoring_functions_impl.register_scoring_function(

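One Python subtlety in the skip guard added above: ("meta-reference") is a parenthesized string, not a tuple, so `not in` performs a substring test. For the provider types used here the result happens to match tuple membership, as this sketch (not part of the commit) shows:

provider_type = "braintrust"
print(provider_type not in ("meta-reference"))   # True: substring test on a str
print(provider_type not in ("meta-reference",))  # True: membership in a 1-tuple

provider_type = "meta-reference"
print(provider_type not in ("meta-reference"))   # False
print(provider_type not in ("meta-reference",))  # False
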
@@ -97,24 +127,30 @@ async def test_scoring_functions_register(scoring_settings):
 
 
 @pytest.mark.asyncio
-async def test_scoring_score(scoring_settings):
+async def test_scoring_score(scoring_settings, provider_scoring_functions):
     scoring_impl = scoring_settings["scoring_impl"]
     datasets_impl = scoring_settings["datasets_impl"]
     scoring_functions_impl = scoring_settings["scoring_functions_impl"]
     await register_dataset(datasets_impl)
 
     response = await datasets_impl.list_datasets()
     assert len(response) == 1
+
+    # get current provider_type we're testing
+    scoring_functions = await scoring_functions_impl.list_scoring_functions()
+    function_ids = [f.identifier for f in scoring_functions]
+    provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
+    provider_type = provider.__provider_spec__.provider_type
+    if provider_type not in ("meta-reference"):
+        pytest.skip(
+            "Other scoring providers don't support registering scoring functions."
+        )
+
     response = await scoring_impl.score_batch(
         dataset_id=response[0].identifier,
-        scoring_functions=[
-            "meta-reference::equality",
-            "meta-reference::subset_of",
-            "meta-reference::llm_as_judge_8b_correctness",
-        ],
+        scoring_functions=list(provider_scoring_functions[provider_type]),
     )
 
-    assert len(response.results) == 3
-    assert "meta-reference::equality" in response.results
-    assert "meta-reference::subset_of" in response.results
-    assert "meta-reference::llm_as_judge_8b_correctness" in response.results
+    assert len(response.results) == len(provider_scoring_functions[provider_type])
+    for x in provider_scoring_functions[provider_type]:
+        assert x in response.results

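For readers following the assertions above: response.results is keyed by scoring-function identifier, and each value carries per-row scores plus aggregates (field names taken from the provider code earlier in this diff; they may differ by provider). A hedged helper sketch, not part of the tests:

def summarize_score_batch(response) -> None:
    # ScoreBatchResponse.results maps scoring_fn_id -> ScoringResult.
    for fn_id, result in response.results.items():
        print(fn_id, "->", result.aggregated_results)
        for row in result.score_rows[:3]:
            print("   sample row:", row)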