mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
equality scorer
This commit is contained in:
parent
cad8c8710b
commit
4b1d7da030
5 changed files with 53 additions and 23 deletions
|
@ -14,7 +14,6 @@ from llama_stack.apis.scoring_functions import * # noqa: F403
|
|||
|
||||
|
||||
ScoringResult = Dict[str, Any]
|
||||
SingleScoringResult = Dict[str, Any]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -25,7 +24,7 @@ class ScoreBatchResponse(BaseModel):
|
|||
@json_schema_type
|
||||
class ScoreResponse(BaseModel):
|
||||
# each key in the dict is a scoring function name
|
||||
results: List[Dict[str, ScoringResult]]
|
||||
results: Dict[str, ScoringResult]
|
||||
|
||||
|
||||
class ScoringFunctionStore(Protocol):
|
||||
|
|
|
@ -217,6 +217,7 @@ class ScoringRouter(Scoring):
|
|||
async def score(
|
||||
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
|
||||
) -> ScoreResponse:
|
||||
res = {}
|
||||
# look up and map each scoring function to its provider impl
|
||||
for fn_identifier in scoring_functions:
|
||||
score_response = await self.routing_table.get_provider_impl(
|
||||
|
@ -225,6 +226,6 @@ class ScoringRouter(Scoring):
|
|||
input_rows=input_rows,
|
||||
scoring_functions=[fn_identifier],
|
||||
)
|
||||
print(
|
||||
f"fn_identifier={fn_identifier}, score_response={score_response}",
|
||||
)
|
||||
res.update(score_response.results)
|
||||
|
||||
return ScoreResponse(results=res)
|
||||
|
|
|
@ -17,6 +17,8 @@ class BaseScorer(ABC):
|
|||
- aggregate(self, scorer_results)
|
||||
"""
|
||||
|
||||
scoring_function_def: DeterministicFunctionDef
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
|
|
@ -7,6 +7,9 @@
|
|||
from llama_stack.providers.impls.meta_reference.scoring.scorer.base_scorer import (
|
||||
BaseScorer,
|
||||
)
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
|
||||
|
||||
class EqualityScorer(BaseScorer):
|
||||
|
@ -14,11 +17,28 @@ class EqualityScorer(BaseScorer):
|
|||
A scorer that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
|
||||
"""
|
||||
|
||||
def __init__(self, target: str) -> None:
|
||||
"""
|
||||
Initialize the EqualityScorer with a target string.
|
||||
scoring_function_def = DeterministicFunctionDef(
|
||||
identifier="equality",
|
||||
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
|
||||
parameters=[],
|
||||
return_type=NumberType(),
|
||||
)
|
||||
|
||||
Args:
|
||||
target (str): The target string to match against.
|
||||
"""
|
||||
self.target = target
|
||||
def score_row(self, input_row: Dict[str, Any]) -> ScoringResult:
|
||||
expected_answer = input_row["expected_answer"]
|
||||
generated_answer = input_row["generated_answer"]
|
||||
score = 1.0 if expected_answer == generated_answer else 0.0
|
||||
return {
|
||||
"score": score,
|
||||
}
|
||||
|
||||
def aggregate(self, scoring_results: List[ScoringResult]) -> ScoringResult:
|
||||
assert len(scoring_results) > 0, "Empty scoring results provided."
|
||||
num_correct = sum(result["score"] for result in scoring_results)
|
||||
avg_score = num_correct / len(scoring_results)
|
||||
|
||||
return {
|
||||
"accuracy": avg_score,
|
||||
"num_correct": num_correct,
|
||||
"num_total": len(scoring_results),
|
||||
}
|
||||
|
|
|
@ -14,9 +14,18 @@ from llama_stack.apis.datasetio import * # noqa: F403
|
|||
from termcolor import cprint
|
||||
|
||||
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
|
||||
from llama_stack.providers.impls.meta_reference.scoring.scorer.equality_scorer import (
|
||||
EqualityScorer,
|
||||
)
|
||||
|
||||
from .config import MetaReferenceScoringConfig
|
||||
|
||||
SUPPORTED_SCORERS = [
|
||||
EqualityScorer,
|
||||
]
|
||||
|
||||
SCORER_REGISTRY = {x.scoring_function_def.identifier: x for x in SUPPORTED_SCORERS}
|
||||
|
||||
|
||||
class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
||||
def __init__(
|
||||
|
@ -31,14 +40,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
|||
async def shutdown(self) -> None: ...
|
||||
|
||||
async def list_scoring_functions(self) -> List[ScoringFunctionDef]:
|
||||
return [
|
||||
DeterministicFunctionDef(
|
||||
identifier="equality",
|
||||
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
|
||||
parameters=[],
|
||||
return_type=NumberType(),
|
||||
)
|
||||
]
|
||||
return [x.scoring_function_def for x in SUPPORTED_SCORERS]
|
||||
|
||||
async def register_scoring_function(self, function_def: ScoringFunctionDef) -> None:
|
||||
raise NotImplementedError(
|
||||
|
@ -53,7 +55,13 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
|||
async def score(
|
||||
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
|
||||
) -> ScoreResponse:
|
||||
print(
|
||||
f"scoring input_rows {input_rows} on scoring_functions {scoring_functions}"
|
||||
res = {}
|
||||
for scoring_fn_id in scoring_functions:
|
||||
scorer = SCORER_REGISTRY[scoring_fn_id]()
|
||||
score_results = scorer.score(input_rows)
|
||||
agg_results = scorer.aggregate(score_results)
|
||||
res[scoring_fn_id] = agg_results
|
||||
|
||||
return ScoreResponse(
|
||||
results=res,
|
||||
)
|
||||
return ScoreResponse()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue