mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 09:21:45 +00:00
scoring
This commit is contained in:
parent
13720cbedf
commit
1094f26426
3 changed files with 13 additions and 5 deletions
|
@ -48,7 +48,7 @@ class Scoring(Protocol):
|
||||||
async def score_batch(
|
async def score_batch(
|
||||||
self,
|
self,
|
||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
scoring_functions: Dict[str, Optional[ScoringFnParams]],
|
||||||
save_results_dataset: bool = False,
|
save_results_dataset: bool = False,
|
||||||
) -> ScoreBatchResponse: ...
|
) -> ScoreBatchResponse: ...
|
||||||
|
|
||||||
|
@ -56,5 +56,5 @@ class Scoring(Protocol):
|
||||||
async def score(
|
async def score(
|
||||||
self,
|
self,
|
||||||
input_rows: List[Dict[str, Any]],
|
input_rows: List[Dict[str, Any]],
|
||||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
scoring_functions: Dict[str, Optional[ScoringFnParams]],
|
||||||
) -> ScoreResponse: ...
|
) -> ScoreResponse: ...
|
||||||
|
|
|
@ -99,7 +99,7 @@ class BraintrustScoringImpl(
|
||||||
async def score_batch(
|
async def score_batch(
|
||||||
self,
|
self,
|
||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
scoring_functions: List[str],
|
scoring_functions: Dict[str, Optional[ScoringFnParams]],
|
||||||
save_results_dataset: bool = False,
|
save_results_dataset: bool = False,
|
||||||
) -> ScoreBatchResponse:
|
) -> ScoreBatchResponse:
|
||||||
await self.set_api_key()
|
await self.set_api_key()
|
||||||
|
@ -135,7 +135,9 @@ class BraintrustScoringImpl(
|
||||||
return {"score": score, "metadata": result.metadata}
|
return {"score": score, "metadata": result.metadata}
|
||||||
|
|
||||||
async def score(
|
async def score(
|
||||||
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
|
self,
|
||||||
|
input_rows: List[Dict[str, Any]],
|
||||||
|
scoring_functions: Dict[str, Optional[ScoringFnParams]],
|
||||||
) -> ScoreResponse:
|
) -> ScoreResponse:
|
||||||
await self.set_api_key()
|
await self.set_api_key()
|
||||||
res = {}
|
res = {}
|
||||||
|
@ -151,6 +153,12 @@ class BraintrustScoringImpl(
|
||||||
scoring_fn_id
|
scoring_fn_id
|
||||||
].params.aggregation_functions
|
].params.aggregation_functions
|
||||||
|
|
||||||
|
# override scoring_fn params if provided
|
||||||
|
if scoring_functions[scoring_fn_id] is not None:
|
||||||
|
override_params = scoring_functions[scoring_fn_id]
|
||||||
|
if override_params.aggregation_functions:
|
||||||
|
aggregation_functions = override_params.aggregation_functions
|
||||||
|
|
||||||
agg_results = aggregate_metrics(score_results, aggregation_functions)
|
agg_results = aggregate_metrics(score_results, aggregation_functions)
|
||||||
res[scoring_fn_id] = ScoringResult(
|
res[scoring_fn_id] = ScoringResult(
|
||||||
score_rows=score_results,
|
score_rows=score_results,
|
||||||
|
|
|
@ -197,7 +197,7 @@ class TestScoring:
|
||||||
judge_score_regexes=[r"Score: (\d+)"],
|
judge_score_regexes=[r"Score: (\d+)"],
|
||||||
aggregation_functions=aggr_fns,
|
aggregation_functions=aggr_fns,
|
||||||
)
|
)
|
||||||
elif x.provider_id == "basic":
|
elif x.provider_id == "basic" or x.provider_id == "braintrust":
|
||||||
if "regex_parser" in x.identifier:
|
if "regex_parser" in x.identifier:
|
||||||
scoring_functions[x.identifier] = RegexParserScoringFnParams(
|
scoring_functions[x.identifier] = RegexParserScoringFnParams(
|
||||||
aggregation_functions=aggr_fns,
|
aggregation_functions=aggr_fns,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue