datasetio test fix

This commit is contained in:
Xi Yan 2024-10-27 16:43:02 -07:00
parent d3d2243dfb
commit 68346fac39
6 changed files with 100 additions and 46 deletions

View file

@ -11,6 +11,8 @@ from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from autoevals.llm import Factuality
from autoevals.ragas import AnswerCorrectness
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from .config import BraintrustScoringConfig
@ -28,29 +30,29 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
self.datasets_api = datasets_api
self.scoring_fn_id_impls = {}
async def initialize(self) -> None: ...
# for x in FIXED_FNS:
# impl = x()
# await impl.initialize()
# for fn_defs in impl.get_supported_scoring_fn_defs():
# self.scoring_fn_id_impls[fn_defs.identifier] = impl
# for x in LLM_JUDGE_FNS:
# impl = x(inference_api=self.inference_api)
# await impl.initialize()
# for fn_defs in impl.get_supported_scoring_fn_defs():
# self.scoring_fn_id_impls[fn_defs.identifier] = impl
# self.llm_as_judge_fn = impl
async def initialize(self) -> None:
self.scoring_fn_id_impls = {
"braintrust::factuality": Factuality(),
"braintrust::answer-correctness": AnswerCorrectness(),
}
async def shutdown(self) -> None: ...
async def list_scoring_functions(self) -> List[ScoringFnDef]:
return []
# return [
# fn_defs
# for impl in self.scoring_fn_id_impls.values()
# for fn_defs in impl.get_supported_scoring_fn_defs()
# ]
return [
ScoringFnDef(
identifier="braintrust::factuality",
description="Test whether an output is factual, compared to an original (`expected`) value.",
parameters=[],
return_type=NumberType(),
),
ScoringFnDef(
identifier="braintrust::answer-correctness",
description="Test whether an output is factual, compared to an original (`expected`) value.",
parameters=[],
return_type=NumberType(),
),
]
async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
# self.llm_as_judge_fn.register_scoring_fn_def(function_def)
@ -81,7 +83,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
print("score_batch")
# await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
# all_rows = await self.datasetio_api.get_rows_paginated(
# dataset_id=dataset_id,
# rows_in_page=-1,
@ -103,16 +105,17 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
) -> ScoreResponse:
res = {}
print("score")
# for scoring_fn_id in scoring_functions:
# if scoring_fn_id not in self.scoring_fn_id_impls:
# raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
# scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
# score_results = await scoring_fn.score(input_rows, scoring_fn_id)
# agg_results = await scoring_fn.aggregate(score_results)
# res[scoring_fn_id] = ScoringResult(
# score_rows=score_results,
# aggregated_results=agg_results,
# )
for scoring_fn_id in scoring_functions:
if scoring_fn_id not in self.scoring_fn_id_impls:
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
# scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
# score_results = await scoring_fn.score(input_rows, scoring_fn_id)
# agg_results = await scoring_fn.aggregate(score_results)
# res[scoring_fn_id] = ScoringResult(
# score_rows=score_results,
# aggregated_results=agg_results,
# )
return ScoreResponse(
results=res,

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,10 @@
{
"identifier": "braintrust::factuality",
"description": "Test whether an output is factual, compared to an original (`expected`) value. ",
"metadata": {},
"parameters": [],
"return_type": {
"type": "number"
},
"context": null
}