[Evals API][3/n] scoring_functions / scoring meta-reference implementations (#296)

* wip

* dataset validation

* test_scoring

* cleanup

* clean up test

* comments

* error checking

* dataset client

* test client:

* datasetio client

* clean up

* basic scoring function works

* scorer wip

* equality scorer

* score batch impl

* score batch

* update scoring test

* refactor

* validate scorer input

* address comments

* add all rows scores to ScoringResult

* bugfix

* scoring function def rename
This commit is contained in:
Xi Yan 2024-10-24 14:52:30 -07:00 committed by GitHub
parent e70420a06e
commit cb84034567
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
28 changed files with 904 additions and 51 deletions

View file

@ -13,18 +13,27 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
ScoringResult = Dict[str, Any]
# mapping of metric to value
ScoringResultRow = Dict[str, Any]
@json_schema_type
class ScoringResult(BaseModel):
score_rows: List[ScoringResultRow]
# aggregated metrics to value
aggregated_results: Dict[str, Any]
@json_schema_type
class ScoreBatchResponse(BaseModel):
dataset_id: str
dataset_id: Optional[str] = None
results: Dict[str, ScoringResult]
@json_schema_type
class ScoreResponse(BaseModel):
# each key in the dict is a scoring function name
results: List[Dict[str, ScoringResult]]
results: Dict[str, ScoringResult]
class ScoringFunctionStore(Protocol):
@ -37,7 +46,10 @@ class Scoring(Protocol):
@webmethod(route="/scoring/score_batch")
async def score_batch(
self, dataset_id: str, scoring_functions: List[str]
self,
dataset_id: str,
scoring_functions: List[str],
save_results_dataset: bool = False,
) -> ScoreBatchResponse: ...
@webmethod(route="/scoring/score")