From 1094f264267bc5352a0d98097a86b8c228dfe167 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 19 Dec 2024 14:26:58 -0800 Subject: [PATCH] scoring --- llama_stack/apis/scoring/scoring.py | 4 ++-- .../inline/scoring/braintrust/braintrust.py | 12 ++++++++++-- llama_stack/providers/tests/scoring/test_scoring.py | 2 +- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/llama_stack/apis/scoring/scoring.py b/llama_stack/apis/scoring/scoring.py index a47620a3d..e062d0f57 100644 --- a/llama_stack/apis/scoring/scoring.py +++ b/llama_stack/apis/scoring/scoring.py @@ -48,7 +48,7 @@ class Scoring(Protocol): async def score_batch( self, dataset_id: str, - scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, + scoring_functions: Dict[str, Optional[ScoringFnParams]], save_results_dataset: bool = False, ) -> ScoreBatchResponse: ... @@ -56,5 +56,5 @@ class Scoring(Protocol): async def score( self, input_rows: List[Dict[str, Any]], - scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, + scoring_functions: Dict[str, Optional[ScoringFnParams]], ) -> ScoreResponse: ... diff --git a/llama_stack/providers/inline/scoring/braintrust/braintrust.py b/llama_stack/providers/inline/scoring/braintrust/braintrust.py index fcb48fd33..7a966e67a 100644 --- a/llama_stack/providers/inline/scoring/braintrust/braintrust.py +++ b/llama_stack/providers/inline/scoring/braintrust/braintrust.py @@ -99,7 +99,7 @@ class BraintrustScoringImpl( async def score_batch( self, dataset_id: str, - scoring_functions: List[str], + scoring_functions: Dict[str, Optional[ScoringFnParams]], save_results_dataset: bool = False, ) -> ScoreBatchResponse: await self.set_api_key() @@ -135,7 +135,9 @@ class BraintrustScoringImpl( return {"score": score, "metadata": result.metadata} async def score( - self, input_rows: List[Dict[str, Any]], scoring_functions: List[str] + self, + input_rows: List[Dict[str, Any]], + scoring_functions: Dict[str, Optional[ScoringFnParams]], ) -> ScoreResponse: await self.set_api_key() res = {} @@ -151,6 +153,12 @@ class BraintrustScoringImpl( scoring_fn_id ].params.aggregation_functions + # override scoring_fn params if provided + if scoring_functions[scoring_fn_id] is not None: + override_params = scoring_functions[scoring_fn_id] + if override_params.aggregation_functions: + aggregation_functions = override_params.aggregation_functions + agg_results = aggregate_metrics(score_results, aggregation_functions) res[scoring_fn_id] = ScoringResult( score_rows=score_results, diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py index dce069df0..2643b8fd6 100644 --- a/llama_stack/providers/tests/scoring/test_scoring.py +++ b/llama_stack/providers/tests/scoring/test_scoring.py @@ -197,7 +197,7 @@ class TestScoring: judge_score_regexes=[r"Score: (\d+)"], aggregation_functions=aggr_fns, ) - elif x.provider_id == "basic": + elif x.provider_id == "basic" or x.provider_id == "braintrust": if "regex_parser" in x.identifier: scoring_functions[x.identifier] = RegexParserScoringFnParams( aggregation_functions=aggr_fns,