mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-29 11:24:19 +00:00
[Evals API][4/n] evals with generation meta-reference impl (#303)
* wip * dataset validation * test_scoring * cleanup * clean up test * comments * error checking * dataset client * test client: * datasetio client * clean up * basic scoring function works * scorer wip * equality scorer * score batch impl * score batch * update scoring test * refactor * validate scorer input * address comments * evals with generation * add all rows scores to ScoringResult * minor typing * bugfix * scoring function def rename * rebase name * refactor * address comments * Update iOS inference instructions for new quantization * Small updates to quantization config * Fix score threshold in faiss * Bump version to 0.0.45 * Handle both ipv6 and ipv4 interfaces together * update manifest for build templates * Update getting_started.md * chatcompletion & completion input type validation * inclusion->subsetof * error checking * scoring_function -> scoring_fn rename, scorer -> scoring_fn rename * address comments * [Evals API][5/n] fixes to generate openapi spec (#323) * generate openapi * typing comment, dataset -> dataset_id * remove custom type * sample eval run.yaml --------- Co-authored-by: Dalton Flanagan <6599399+dltn@users.noreply.github.com> Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
parent
426d821e7f
commit
abdf7cddf3
31 changed files with 3371 additions and 1296 deletions
|
@ -34,7 +34,7 @@ RoutableObject = Union[
|
|||
ShieldDef,
|
||||
MemoryBankDef,
|
||||
DatasetDef,
|
||||
ScoringFunctionDef,
|
||||
ScoringFnDef,
|
||||
]
|
||||
|
||||
RoutableObjectWithProvider = Union[
|
||||
|
@ -42,7 +42,7 @@ RoutableObjectWithProvider = Union[
|
|||
ShieldDefWithProvider,
|
||||
MemoryBankDefWithProvider,
|
||||
DatasetDefWithProvider,
|
||||
ScoringFunctionDefWithProvider,
|
||||
ScoringFnDefWithProvider,
|
||||
]
|
||||
|
||||
RoutedProtocol = Union[
|
||||
|
|
|
@ -14,6 +14,7 @@ from llama_stack.distribution.datatypes import * # noqa: F403
|
|||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.eval import Eval
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inspect import Inspect
|
||||
from llama_stack.apis.memory import Memory
|
||||
|
@ -46,6 +47,7 @@ def api_protocol_map() -> Dict[Api, Any]:
|
|||
Api.datasetio: DatasetIO,
|
||||
Api.scoring_functions: ScoringFunctions,
|
||||
Api.scoring: Scoring,
|
||||
Api.eval: Eval,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -100,7 +100,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
scoring_functions = await p.list_scoring_functions()
|
||||
add_objects(
|
||||
[
|
||||
ScoringFunctionDefWithProvider(**s.dict(), provider_id=pid)
|
||||
ScoringFnDefWithProvider(**s.dict(), provider_id=pid)
|
||||
for s in scoring_functions
|
||||
]
|
||||
)
|
||||
|
@ -239,7 +239,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
|||
|
||||
|
||||
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring):
|
||||
async def list_scoring_functions(self) -> List[ScoringFunctionDefWithProvider]:
|
||||
async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]:
|
||||
objects = []
|
||||
for objs in self.registry.values():
|
||||
objects.extend(objs)
|
||||
|
@ -247,10 +247,10 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring):
|
|||
|
||||
async def get_scoring_function(
|
||||
self, name: str
|
||||
) -> Optional[ScoringFunctionDefWithProvider]:
|
||||
) -> Optional[ScoringFnDefWithProvider]:
|
||||
return self.get_object_by_identifier(name)
|
||||
|
||||
async def register_scoring_function(
|
||||
self, function_def: ScoringFunctionDefWithProvider
|
||||
self, function_def: ScoringFnDefWithProvider
|
||||
) -> None:
|
||||
await self.register_object(function_def)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue