This commit is contained in:
Xi Yan 2024-12-19 14:26:58 -08:00
parent 13720cbedf
commit 1094f26426
3 changed files with 13 additions and 5 deletions

View file

@ -48,7 +48,7 @@ class Scoring(Protocol):
async def score_batch( async def score_batch(
self, self,
dataset_id: str, dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, scoring_functions: Dict[str, Optional[ScoringFnParams]],
save_results_dataset: bool = False, save_results_dataset: bool = False,
) -> ScoreBatchResponse: ... ) -> ScoreBatchResponse: ...
@ -56,5 +56,5 @@ class Scoring(Protocol):
async def score( async def score(
self, self,
input_rows: List[Dict[str, Any]], input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, scoring_functions: Dict[str, Optional[ScoringFnParams]],
) -> ScoreResponse: ... ) -> ScoreResponse: ...

View file

@ -99,7 +99,7 @@ class BraintrustScoringImpl(
async def score_batch( async def score_batch(
self, self,
dataset_id: str, dataset_id: str,
scoring_functions: List[str], scoring_functions: Dict[str, Optional[ScoringFnParams]],
save_results_dataset: bool = False, save_results_dataset: bool = False,
) -> ScoreBatchResponse: ) -> ScoreBatchResponse:
await self.set_api_key() await self.set_api_key()
@ -135,7 +135,9 @@ class BraintrustScoringImpl(
return {"score": score, "metadata": result.metadata} return {"score": score, "metadata": result.metadata}
async def score( async def score(
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str] self,
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]],
) -> ScoreResponse: ) -> ScoreResponse:
await self.set_api_key() await self.set_api_key()
res = {} res = {}
@ -151,6 +153,12 @@ class BraintrustScoringImpl(
scoring_fn_id scoring_fn_id
].params.aggregation_functions ].params.aggregation_functions
# override scoring_fn params if provided
if scoring_functions[scoring_fn_id] is not None:
override_params = scoring_functions[scoring_fn_id]
if override_params.aggregation_functions:
aggregation_functions = override_params.aggregation_functions
agg_results = aggregate_metrics(score_results, aggregation_functions) agg_results = aggregate_metrics(score_results, aggregation_functions)
res[scoring_fn_id] = ScoringResult( res[scoring_fn_id] = ScoringResult(
score_rows=score_results, score_rows=score_results,

View file

@ -197,7 +197,7 @@ class TestScoring:
judge_score_regexes=[r"Score: (\d+)"], judge_score_regexes=[r"Score: (\d+)"],
aggregation_functions=aggr_fns, aggregation_functions=aggr_fns,
) )
elif x.provider_id == "basic": elif x.provider_id == "basic" or x.provider_id == "braintrust":
if "regex_parser" in x.identifier: if "regex_parser" in x.identifier:
scoring_functions[x.identifier] = RegexParserScoringFnParams( scoring_functions[x.identifier] = RegexParserScoringFnParams(
aggregation_functions=aggr_fns, aggregation_functions=aggr_fns,