From 24dce9cb7a916491ee2ccbc96a7bc86c0ae78562 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 24 Oct 2024 12:08:57 -0700 Subject: [PATCH] minor typing --- .../impls/meta_reference/scoring/scorer/inclusion_scorer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama_stack/providers/impls/meta_reference/scoring/scorer/inclusion_scorer.py b/llama_stack/providers/impls/meta_reference/scoring/scorer/inclusion_scorer.py index 506bc60a7..d317bfba4 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scorer/inclusion_scorer.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scorer/inclusion_scorer.py @@ -24,7 +24,7 @@ class InclusionScorer(BaseScorer): return_type=NumberType(), ) - def score_row(self, input_row: Dict[str, Any]) -> ScoringResult: + def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow: assert "expected_answer" in input_row, "Expected answer not found in input row." assert ( "generated_answer" in input_row @@ -37,7 +37,7 @@ class InclusionScorer(BaseScorer): "score": score, } - def aggregate(self, scoring_results: List[ScoringResult]) -> ScoringResult: + def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]: assert len(scoring_results) > 0, "Empty scoring results provided." num_correct = sum(result["score"] for result in scoring_results) avg_score = num_correct / len(scoring_results)