forked from phoenix-oss/llama-stack-mirror
[Evals API][3/n] scoring_functions / scoring meta-reference implementations (#296)
* wip * dataset validation * test_scoring * cleanup * clean up test * comments * error checking * dataset client * test client: * datasetio client * clean up * basic scoring function works * scorer wip * equality scorer * score batch impl * score batch * update scoring test * refactor * validate scorer input * address comments * add all rows scores to ScoringResult * bugfix * scoring function def rename
This commit is contained in:
parent
e70420a06e
commit
cb84034567
28 changed files with 904 additions and 51 deletions
|
@ -30,6 +30,8 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
|
|||
await p.register_memory_bank(obj)
|
||||
elif api == Api.datasetio:
|
||||
await p.register_dataset(obj)
|
||||
elif api == Api.scoring:
|
||||
await p.register_scoring_function(obj)
|
||||
else:
|
||||
raise ValueError(f"Unknown API {api} for registering object with provider")
|
||||
|
||||
|
@ -93,7 +95,15 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
for d in datasets:
|
||||
d.provider_id = pid
|
||||
|
||||
add_objects(datasets)
|
||||
elif api == Api.scoring:
|
||||
p.scoring_function_store = self
|
||||
scoring_functions = await p.list_scoring_functions()
|
||||
add_objects(
|
||||
[
|
||||
ScoringFunctionDefWithProvider(**s.dict(), provider_id=pid)
|
||||
for s in scoring_functions
|
||||
]
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
for p in self.impls_by_provider_id.values():
|
||||
|
@ -109,6 +119,10 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
return ("Safety", "shield")
|
||||
elif isinstance(self, MemoryBanksRoutingTable):
|
||||
return ("Memory", "memory_bank")
|
||||
elif isinstance(self, DatasetsRoutingTable):
|
||||
return ("DatasetIO", "dataset")
|
||||
elif isinstance(self, ScoringFunctionsRoutingTable):
|
||||
return ("Scoring", "scoring_function")
|
||||
else:
|
||||
raise ValueError("Unknown routing table type")
|
||||
|
||||
|
@ -218,7 +232,25 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
|||
async def get_dataset(
|
||||
self, dataset_identifier: str
|
||||
) -> Optional[DatasetDefWithProvider]:
|
||||
return self.get_object_by_identifier(identifier)
|
||||
return self.get_object_by_identifier(dataset_identifier)
|
||||
|
||||
async def register_dataset(self, dataset_def: DatasetDefWithProvider) -> None:
|
||||
await self.register_object(dataset_def)
|
||||
|
||||
|
||||
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring):
|
||||
async def list_scoring_functions(self) -> List[ScoringFunctionDefWithProvider]:
|
||||
objects = []
|
||||
for objs in self.registry.values():
|
||||
objects.extend(objs)
|
||||
return objects
|
||||
|
||||
async def get_scoring_function(
|
||||
self, name: str
|
||||
) -> Optional[ScoringFunctionDefWithProvider]:
|
||||
return self.get_object_by_identifier(name)
|
||||
|
||||
async def register_scoring_function(
|
||||
self, function_def: ScoringFunctionDefWithProvider
|
||||
) -> None:
|
||||
await self.register_object(function_def)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue