diff --git a/llama_stack/apis/scoring/scoring.py b/llama_stack/apis/scoring/scoring.py
index 54d8f7487..adac34d55 100644
--- a/llama_stack/apis/scoring/scoring.py
+++ b/llama_stack/apis/scoring/scoring.py
@@ -13,7 +13,15 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.scoring_functions import *  # noqa: F403
 
 
-ScoringResult = Dict[str, Any]
+# mapping of metric to value
+ScoringResultRow = Dict[str, Any]
+
+
+@json_schema_type
+class ScoringResult(BaseModel):
+    score_rows: List[ScoringResultRow]
+    # aggregated metrics to value
+    aggregated_results: Dict[str, Any]
 
 
 @json_schema_type
diff --git a/llama_stack/providers/impls/meta_reference/scoring/scorer/base_scorer.py b/llama_stack/providers/impls/meta_reference/scoring/scorer/base_scorer.py
index 39d040e1c..9c982948e 100644
--- a/llama_stack/providers/impls/meta_reference/scoring/scorer/base_scorer.py
+++ b/llama_stack/providers/impls/meta_reference/scoring/scorer/base_scorer.py
@@ -26,12 +26,12 @@ class BaseScorer(ABC):
         return self.__class__.__name__
 
     @abstractmethod
-    def score_row(self, input_row: Dict[str, Any]) -> ScoringResult:
+    def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
         raise NotImplementedError()
 
     @abstractmethod
-    def aggregate(self, scoring_results: List[ScoringResult]) -> ScoringResult:
+    def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
         raise NotImplementedError()
 
-    def score(self, input_rows: List[Dict[str, Any]]) -> List[ScoringResult]:
+    def score(self, input_rows: List[Dict[str, Any]]) -> List[ScoringResultRow]:
         return [self.score_row(input_row) for input_row in input_rows]
diff --git a/llama_stack/providers/impls/meta_reference/scoring/scorer/equality_scorer.py b/llama_stack/providers/impls/meta_reference/scoring/scorer/equality_scorer.py
index e79a96068..b9b8b1eee 100644
--- a/llama_stack/providers/impls/meta_reference/scoring/scorer/equality_scorer.py
+++ b/llama_stack/providers/impls/meta_reference/scoring/scorer/equality_scorer.py
@@ -24,7 +24,7 @@ class EqualityScorer(BaseScorer):
         return_type=NumberType(),
     )
 
-    def score_row(self, input_row: Dict[str, Any]) -> ScoringResult:
+    def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
         assert "expected_answer" in input_row, "Expected answer not found in input row."
         assert (
             "generated_answer" in input_row
@@ -37,7 +37,7 @@ class EqualityScorer(BaseScorer):
             "score": score,
         }
 
-    def aggregate(self, scoring_results: List[ScoringResult]) -> ScoringResult:
+    def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
         assert len(scoring_results) > 0, "Empty scoring results provided."
         num_correct = sum(result["score"] for result in scoring_results)
         avg_score = num_correct / len(scoring_results)
diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring.py b/llama_stack/providers/impls/meta_reference/scoring/scoring.py
index 73f9fcc5a..0d32c8195 100644
--- a/llama_stack/providers/impls/meta_reference/scoring/scoring.py
+++ b/llama_stack/providers/impls/meta_reference/scoring/scoring.py
@@ -99,7 +99,10 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
             scorer = SCORER_REGISTRY[scoring_fn_id]()
             score_results = scorer.score(input_rows)
             agg_results = scorer.aggregate(score_results)
-            res[scoring_fn_id] = agg_results
+            res[scoring_fn_id] = ScoringResult(
+                score_rows=score_results,
+                aggregated_results=agg_results,
+            )
 
         return ScoreResponse(
             results=res,
diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py
index 1af5c05cf..2218faa54 100644
--- a/llama_stack/providers/tests/scoring/test_scoring.py
+++ b/llama_stack/providers/tests/scoring/test_scoring.py
@@ -66,3 +66,4 @@ async def test_scoring_score(scoring_settings):
     )
 
     assert len(response.results) == 1
+    assert "equality" in response.results