add all rows scores to ScoringResult

This commit is contained in:
Xi Yan 2024-10-24 11:53:15 -07:00
parent 071dba8871
commit a3a8f32541
5 changed files with 19 additions and 7 deletions

View file

@ -13,7 +13,15 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403 from llama_stack.apis.scoring_functions import * # noqa: F403
ScoringResult = Dict[str, Any] # mapping of metric to value
ScoringResultRow = Dict[str, Any]
@json_schema_type
class ScoringResult(BaseModel):
score_rows: List[ScoringResultRow]
# aggregated metrics to value
aggregated_results: Dict[str, Any]
@json_schema_type @json_schema_type

View file

@ -26,12 +26,12 @@ class BaseScorer(ABC):
return self.__class__.__name__ return self.__class__.__name__
@abstractmethod @abstractmethod
def score_row(self, input_row: Dict[str, Any]) -> ScoringResult: def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
raise NotImplementedError() raise NotImplementedError()
@abstractmethod @abstractmethod
def aggregate(self, scoring_results: List[ScoringResult]) -> ScoringResult: def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
raise NotImplementedError() raise NotImplementedError()
def score(self, input_rows: List[Dict[str, Any]]) -> List[ScoringResult]: def score(self, input_rows: List[Dict[str, Any]]) -> List[ScoringResultRow]:
return [self.score_row(input_row) for input_row in input_rows] return [self.score_row(input_row) for input_row in input_rows]

View file

@ -24,7 +24,7 @@ class EqualityScorer(BaseScorer):
return_type=NumberType(), return_type=NumberType(),
) )
def score_row(self, input_row: Dict[str, Any]) -> ScoringResult: def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
assert "expected_answer" in input_row, "Expected answer not found in input row." assert "expected_answer" in input_row, "Expected answer not found in input row."
assert ( assert (
"generated_answer" in input_row "generated_answer" in input_row
@ -37,7 +37,7 @@ class EqualityScorer(BaseScorer):
"score": score, "score": score,
} }
def aggregate(self, scoring_results: List[ScoringResult]) -> ScoringResult: def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
assert len(scoring_results) > 0, "Empty scoring results provided." assert len(scoring_results) > 0, "Empty scoring results provided."
num_correct = sum(result["score"] for result in scoring_results) num_correct = sum(result["score"] for result in scoring_results)
avg_score = num_correct / len(scoring_results) avg_score = num_correct / len(scoring_results)

View file

@ -99,7 +99,10 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
scorer = SCORER_REGISTRY[scoring_fn_id]() scorer = SCORER_REGISTRY[scoring_fn_id]()
score_results = scorer.score(input_rows) score_results = scorer.score(input_rows)
agg_results = scorer.aggregate(score_results) agg_results = scorer.aggregate(score_results)
res[scoring_fn_id] = agg_results res[scoring_fn_id] = ScoringResult(
score_rows=score_results,
aggregated_results=agg_results,
)
return ScoreResponse( return ScoreResponse(
results=res, results=res,

View file

@ -66,3 +66,4 @@ async def test_scoring_score(scoring_settings):
) )
assert len(response.results) == 1 assert len(response.results) == 1
assert "equality" in response.results