equality scorer

Xi Yan 2024-10-23 16:07:17 -07:00
parent cad8c8710b
commit 4b1d7da030
5 changed files with 53 additions and 23 deletions

View file

@@ -14,7 +14,6 @@ from llama_stack.apis.scoring_functions import *  # noqa: F403
 ScoringResult = Dict[str, Any]
-SingleScoringResult = Dict[str, Any]
 
 @json_schema_type
@@ -25,7 +24,7 @@ class ScoreBatchResponse(BaseModel):
 @json_schema_type
 class ScoreResponse(BaseModel):
     # each key in the dict is a scoring function name
-    results: List[Dict[str, ScoringResult]]
+    results: Dict[str, ScoringResult]
 
 class ScoringFunctionStore(Protocol):
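With this change, each key of ScoreResponse.results is a scoring function identifier mapped to a single aggregated result, rather than a list of per-row dicts. A minimal sketch of the new shape, using standalone stand-ins instead of the llama_stack models (pydantic matches what ScoreResponse already builds on):

from typing import Any, Dict

from pydantic import BaseModel

# Stand-in for the llama_stack alias of the same name.
ScoringResult = Dict[str, Any]


class ScoreResponse(BaseModel):
    # each key in the dict is a scoring function name, e.g. "equality"
    results: Dict[str, ScoringResult]


# Old shape: results=[{"equality": {...}}]  (a list of dicts)
# New shape: results={"equality": {...}}    (one dict keyed by function name)
response = ScoreResponse(
    results={"equality": {"accuracy": 0.5, "num_correct": 1, "num_total": 2}}
)
print(response.results["equality"]["accuracy"])  # 0.5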

View file

@@ -217,6 +217,7 @@ class ScoringRouter(Scoring):
     async def score(
         self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
     ) -> ScoreResponse:
+        res = {}
         # look up and map each scoring function to its provider impl
         for fn_identifier in scoring_functions:
             score_response = await self.routing_table.get_provider_impl(
@@ -225,6 +226,6 @@ class ScoringRouter(Scoring):
                 input_rows=input_rows,
                 scoring_functions=[fn_identifier],
             )
-            print(
-                f"fn_identifier={fn_identifier}, score_response={score_response}",
-            )
+            res.update(score_response.results)
+
+        return ScoreResponse(results=res)
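The router now accumulates each provider's results into one dict instead of printing them. A rough, self-contained sketch of that fan-out-and-merge pattern with a fabricated routing table and provider (these stand-ins are not llama_stack code):

import asyncio
from types import SimpleNamespace
from typing import Any, Dict, List


class FakeScoringProvider:
    # Stand-in provider: returns an object with a .results dict, like ScoreResponse.
    async def score(self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]):
        return SimpleNamespace(results={fn: {"accuracy": 1.0} for fn in scoring_functions})


async def route_and_merge(input_rows, scoring_functions, routing_table) -> Dict[str, Any]:
    res: Dict[str, Any] = {}
    for fn_identifier in scoring_functions:
        # one call per scoring function, routed to whichever provider registered it
        provider = routing_table[fn_identifier]
        score_response = await provider.score(
            input_rows=input_rows, scoring_functions=[fn_identifier]
        )
        # merge that provider's single-function results into the combined dict
        res.update(score_response.results)
    return res


routing_table = {"equality": FakeScoringProvider()}
print(asyncio.run(route_and_merge([{"x": 1}], ["equality"], routing_table)))
# {'equality': {'accuracy': 1.0}}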

View file

@@ -17,6 +17,8 @@ class BaseScorer(ABC):
     - aggregate(self, scorer_results)
     """
 
+    scoring_function_def: DeterministicFunctionDef
+
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
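The new class attribute lets each concrete scorer carry its own function definition. For orientation, a sketch of the contract subclasses are expected to satisfy; the score() helper that maps score_row over every row is an assumption here, since this diff only shows score_row and aggregate on the subclass:

from abc import ABC, abstractmethod
from typing import Any, Dict, List

# Stand-in alias; the real code uses the llama_stack ScoringResult.
ScoringResult = Dict[str, Any]


class BaseScorerSketch(ABC):
    # subclasses advertise their function definition (DeterministicFunctionDef in the diff)
    scoring_function_def: Any

    @abstractmethod
    def score_row(self, input_row: Dict[str, Any]) -> ScoringResult: ...

    @abstractmethod
    def aggregate(self, scoring_results: List[ScoringResult]) -> ScoringResult: ...

    def score(self, input_rows: List[Dict[str, Any]]) -> List[ScoringResult]:
        # assumed helper: score each row independently, then let aggregate() summarize
        return [self.score_row(input_row) for input_row in input_rows]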

View file

@@ -7,6 +7,9 @@
 from llama_stack.providers.impls.meta_reference.scoring.scorer.base_scorer import (
     BaseScorer,
 )
+from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
+from llama_stack.apis.scoring import *  # noqa: F401, F403
+from llama_stack.apis.common.type_system import *  # noqa: F403
 
 class EqualityScorer(BaseScorer):
@@ -14,11 +17,28 @@ class EqualityScorer(BaseScorer):
     A scorer that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
     """
 
-    def __init__(self, target: str) -> None:
-        """
-        Initialize the EqualityScorer with a target string.
-
-        Args:
-            target (str): The target string to match against.
-        """
-        self.target = target
+    scoring_function_def = DeterministicFunctionDef(
+        identifier="equality",
+        description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
+        parameters=[],
+        return_type=NumberType(),
+    )
+
+    def score_row(self, input_row: Dict[str, Any]) -> ScoringResult:
+        expected_answer = input_row["expected_answer"]
+        generated_answer = input_row["generated_answer"]
+        score = 1.0 if expected_answer == generated_answer else 0.0
+        return {
+            "score": score,
+        }
+
+    def aggregate(self, scoring_results: List[ScoringResult]) -> ScoringResult:
+        assert len(scoring_results) > 0, "Empty scoring results provided."
+        num_correct = sum(result["score"] for result in scoring_results)
+        avg_score = num_correct / len(scoring_results)
+
+        return {
+            "accuracy": avg_score,
+            "num_correct": num_correct,
+            "num_total": len(scoring_results),
+        }
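A quick usage sketch for the new scorer; the two rows are fabricated, but the expected_answer/generated_answer keys match what score_row reads above:

# Assumes the module added in this commit is importable.
from llama_stack.providers.impls.meta_reference.scoring.scorer.equality_scorer import (
    EqualityScorer,
)

rows = [
    {"generated_answer": "4", "expected_answer": "4"},         # exact match -> score 1.0
    {"generated_answer": "Paris", "expected_answer": "Rome"},  # mismatch    -> score 0.0
]

scorer = EqualityScorer()
per_row = [scorer.score_row(row) for row in rows]  # [{'score': 1.0}, {'score': 0.0}]
print(scorer.aggregate(per_row))                   # {'accuracy': 0.5, 'num_correct': 1.0, 'num_total': 2}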

View file

@@ -14,9 +14,18 @@ from llama_stack.apis.datasetio import *  # noqa: F403
 from termcolor import cprint
 
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
+from llama_stack.providers.impls.meta_reference.scoring.scorer.equality_scorer import (
+    EqualityScorer,
+)
 
 from .config import MetaReferenceScoringConfig
 
+SUPPORTED_SCORERS = [
+    EqualityScorer,
+]
+
+SCORER_REGISTRY = {x.scoring_function_def.identifier: x for x in SUPPORTED_SCORERS}
+
 
 class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     def __init__(
@@ -31,14 +40,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     async def shutdown(self) -> None: ...
 
     async def list_scoring_functions(self) -> List[ScoringFunctionDef]:
-        return [
-            DeterministicFunctionDef(
-                identifier="equality",
-                description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
-                parameters=[],
-                return_type=NumberType(),
-            )
-        ]
+        return [x.scoring_function_def for x in SUPPORTED_SCORERS]
 
     async def register_scoring_function(self, function_def: ScoringFunctionDef) -> None:
         raise NotImplementedError(
@@ -53,7 +55,13 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     async def score(
         self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
     ) -> ScoreResponse:
-        print(
-            f"scoring input_rows {input_rows} on scoring_functions {scoring_functions}"
-        )
-        return ScoreResponse()
+        res = {}
+        for scoring_fn_id in scoring_functions:
+            scorer = SCORER_REGISTRY[scoring_fn_id]()
+            score_results = scorer.score(input_rows)
+            agg_results = scorer.aggregate(score_results)
+            res[scoring_fn_id] = agg_results
+
+        return ScoreResponse(
+            results=res,
+        )
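Putting it together, the provider-side flow is now registry lookup -> per-row scoring -> aggregation, keyed by scoring function id. A condensed sketch of that flow (synchronous, with a plain dict in place of ScoreResponse; it assumes the scorer module from this commit is importable):

from typing import Any, Dict, List

# Assumes the module added in this commit is importable.
from llama_stack.providers.impls.meta_reference.scoring.scorer.equality_scorer import (
    EqualityScorer,
)

SUPPORTED_SCORERS = [EqualityScorer]
SCORER_REGISTRY = {x.scoring_function_def.identifier: x for x in SUPPORTED_SCORERS}


def score(input_rows: List[Dict[str, Any]], scoring_functions: List[str]) -> Dict[str, Any]:
    res = {}
    for scoring_fn_id in scoring_functions:
        scorer = SCORER_REGISTRY[scoring_fn_id]()  # instantiate the registered scorer class
        score_results = scorer.score(input_rows)   # one ScoringResult per input row
        res[scoring_fn_id] = scorer.aggregate(score_results)
    return {"results": res}  # stands in for ScoreResponse(results=res)


print(score([{"generated_answer": "a", "expected_answer": "a"}], ["equality"]))
# {'results': {'equality': {'accuracy': 1.0, 'num_correct': 1.0, 'num_total': 1}}}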