diff --git a/llama_stack/providers/impls/meta_reference/scoring/scorer/inclusion_scorer.py b/llama_stack/providers/impls/meta_reference/scoring/scorer/inclusion_scorer.py index 506bc60a7..d317bfba4 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scorer/inclusion_scorer.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scorer/inclusion_scorer.py @@ -24,7 +24,7 @@ class InclusionScorer(BaseScorer): return_type=NumberType(), ) - def score_row(self, input_row: Dict[str, Any]) -> ScoringResult: + def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow: assert "expected_answer" in input_row, "Expected answer not found in input row." assert ( "generated_answer" in input_row @@ -37,7 +37,7 @@ class InclusionScorer(BaseScorer): "score": score, } - def aggregate(self, scoring_results: List[ScoringResult]) -> ScoringResult: + def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]: assert len(scoring_results) > 0, "Empty scoring results provided." num_correct = sum(result["score"] for result in scoring_results) avg_score = num_correct / len(scoring_results)