diff --git a/llama_stack/apis/scoring/scoring.py b/llama_stack/apis/scoring/scoring.py
index de3881e89..edb36c75e 100644
--- a/llama_stack/apis/scoring/scoring.py
+++ b/llama_stack/apis/scoring/scoring.py
@@ -14,7 +14,6 @@ from llama_stack.apis.scoring_functions import * # noqa: F403
 
 
 ScoringResult = Dict[str, Any]
-SingleScoringResult = Dict[str, Any]
 
 
 @json_schema_type
@@ -25,7 +24,7 @@ class ScoreBatchResponse(BaseModel):
 @json_schema_type
 class ScoreResponse(BaseModel):
     # each key in the dict is a scoring function name
-    results: List[Dict[str, ScoringResult]]
+    results: Dict[str, ScoringResult]
 
 
 class ScoringFunctionStore(Protocol):
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 168cf9235..c85ba47d0 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -217,6 +217,7 @@ class ScoringRouter(Scoring):
     async def score(
         self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
    ) -> ScoreResponse:
+        res = {}
         # look up and map each scoring function to its provider impl
         for fn_identifier in scoring_functions:
             score_response = await self.routing_table.get_provider_impl(
@@ -225,6 +226,6 @@ class ScoringRouter(Scoring):
                 input_rows=input_rows,
                 scoring_functions=[fn_identifier],
             )
-            print(
-                f"fn_identifier={fn_identifier}, score_response={score_response}",
-            )
+            res.update(score_response.results)
+
+        return ScoreResponse(results=res)
diff --git a/llama_stack/providers/impls/meta_reference/scoring/scorer/base_scorer.py b/llama_stack/providers/impls/meta_reference/scoring/scorer/base_scorer.py
index 5f35f2ddd..39d040e1c 100644
--- a/llama_stack/providers/impls/meta_reference/scoring/scorer/base_scorer.py
+++ b/llama_stack/providers/impls/meta_reference/scoring/scorer/base_scorer.py
@@ -17,6 +17,8 @@ class BaseScorer(ABC):
     - aggregate(self, scorer_results)
     """
 
+    scoring_function_def: DeterministicFunctionDef
+
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
 
diff --git a/llama_stack/providers/impls/meta_reference/scoring/scorer/equality_scorer.py b/llama_stack/providers/impls/meta_reference/scoring/scorer/equality_scorer.py
index 82ece9ebf..9fd3716b7 100644
--- a/llama_stack/providers/impls/meta_reference/scoring/scorer/equality_scorer.py
+++ b/llama_stack/providers/impls/meta_reference/scoring/scorer/equality_scorer.py
@@ -7,6 +7,9 @@
 from llama_stack.providers.impls.meta_reference.scoring.scorer.base_scorer import (
     BaseScorer,
 )
+from llama_stack.apis.scoring_functions import * # noqa: F401, F403
+from llama_stack.apis.scoring import * # noqa: F401, F403
+from llama_stack.apis.common.type_system import * # noqa: F403
 
 
 class EqualityScorer(BaseScorer):
@@ -14,11 +17,28 @@ class EqualityScorer(BaseScorer):
     A scorer that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
     """
 
-    def __init__(self, target: str) -> None:
-        """
-        Initialize the EqualityScorer with a target string.
+    scoring_function_def = DeterministicFunctionDef(
+        identifier="equality",
+        description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
+        parameters=[],
+        return_type=NumberType(),
+    )
 
-        Args:
-            target (str): The target string to match against.
-        """
-        self.target = target
+    def score_row(self, input_row: Dict[str, Any]) -> ScoringResult:
+        expected_answer = input_row["expected_answer"]
+        generated_answer = input_row["generated_answer"]
+        score = 1.0 if expected_answer == generated_answer else 0.0
+        return {
+            "score": score,
+        }
+
+    def aggregate(self, scoring_results: List[ScoringResult]) -> ScoringResult:
+        assert len(scoring_results) > 0, "Empty scoring results provided."
+        num_correct = sum(result["score"] for result in scoring_results)
+        avg_score = num_correct / len(scoring_results)
+
+        return {
+            "accuracy": avg_score,
+            "num_correct": num_correct,
+            "num_total": len(scoring_results),
+        }
diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring.py b/llama_stack/providers/impls/meta_reference/scoring/scoring.py
index 895d74c53..17763413d 100644
--- a/llama_stack/providers/impls/meta_reference/scoring/scoring.py
+++ b/llama_stack/providers/impls/meta_reference/scoring/scoring.py
@@ -14,9 +14,18 @@ from llama_stack.apis.datasetio import * # noqa: F403
 from termcolor import cprint
 
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
+from llama_stack.providers.impls.meta_reference.scoring.scorer.equality_scorer import (
+    EqualityScorer,
+)
 
 from .config import MetaReferenceScoringConfig
 
+SUPPORTED_SCORERS = [
+    EqualityScorer,
+]
+
+SCORER_REGISTRY = {x.scoring_function_def.identifier: x for x in SUPPORTED_SCORERS}
+
 
 class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     def __init__(
@@ -31,14 +40,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     async def shutdown(self) -> None: ...
 
     async def list_scoring_functions(self) -> List[ScoringFunctionDef]:
-        return [
-            DeterministicFunctionDef(
-                identifier="equality",
-                description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
-                parameters=[],
-                return_type=NumberType(),
-            )
-        ]
+        return [x.scoring_function_def for x in SUPPORTED_SCORERS]
 
     async def register_scoring_function(self, function_def: ScoringFunctionDef) -> None:
         raise NotImplementedError(
@@ -53,7 +55,13 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     async def score(
         self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
     ) -> ScoreResponse:
-        print(
-            f"scoring input_rows {input_rows} on scoring_functions {scoring_functions}"
+        res = {}
+        for scoring_fn_id in scoring_functions:
+            scorer = SCORER_REGISTRY[scoring_fn_id]()
+            score_results = scorer.score(input_rows)
+            agg_results = scorer.aggregate(score_results)
+            res[scoring_fn_id] = agg_results
+
+        return ScoreResponse(
+            results=res,
         )
-        return ScoreResponse()