mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 01:03:59 +00:00
scoring
This commit is contained in:
parent
13720cbedf
commit
1094f26426
3 changed files with 13 additions and 5 deletions
|
@ -48,7 +48,7 @@ class Scoring(Protocol):
|
|||
async def score_batch(
|
||||
self,
|
||||
dataset_id: str,
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]],
|
||||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse: ...
|
||||
|
||||
|
@ -56,5 +56,5 @@ class Scoring(Protocol):
|
|||
async def score(
|
||||
self,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]],
|
||||
) -> ScoreResponse: ...
|
||||
|
|
|
@ -99,7 +99,7 @@ class BraintrustScoringImpl(
|
|||
async def score_batch(
|
||||
self,
|
||||
dataset_id: str,
|
||||
scoring_functions: List[str],
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]],
|
||||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse:
|
||||
await self.set_api_key()
|
||||
|
@ -135,7 +135,9 @@ class BraintrustScoringImpl(
|
|||
return {"score": score, "metadata": result.metadata}
|
||||
|
||||
async def score(
|
||||
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
|
||||
self,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]],
|
||||
) -> ScoreResponse:
|
||||
await self.set_api_key()
|
||||
res = {}
|
||||
|
@ -151,6 +153,12 @@ class BraintrustScoringImpl(
|
|||
scoring_fn_id
|
||||
].params.aggregation_functions
|
||||
|
||||
# override scoring_fn params if provided
|
||||
if scoring_functions[scoring_fn_id] is not None:
|
||||
override_params = scoring_functions[scoring_fn_id]
|
||||
if override_params.aggregation_functions:
|
||||
aggregation_functions = override_params.aggregation_functions
|
||||
|
||||
agg_results = aggregate_metrics(score_results, aggregation_functions)
|
||||
res[scoring_fn_id] = ScoringResult(
|
||||
score_rows=score_results,
|
||||
|
|
|
@ -197,7 +197,7 @@ class TestScoring:
|
|||
judge_score_regexes=[r"Score: (\d+)"],
|
||||
aggregation_functions=aggr_fns,
|
||||
)
|
||||
elif x.provider_id == "basic":
|
||||
elif x.provider_id == "basic" or x.provider_id == "braintrust":
|
||||
if "regex_parser" in x.identifier:
|
||||
scoring_functions[x.identifier] = RegexParserScoringFnParams(
|
||||
aggregation_functions=aggr_fns,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue