mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
migrate evals to resource (#421)
* migrate evals to resource * remove listing of providers's evals * change the order of params in register * fix after rebase * linter fix --------- Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
parent
b95cb5308f
commit
3802edfc50
5 changed files with 63 additions and 56 deletions
|
@ -105,8 +105,6 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
|
||||
elif api == Api.eval:
|
||||
p.eval_task_store = self
|
||||
eval_tasks = await p.list_eval_tasks()
|
||||
await add_objects(eval_tasks, pid, EvalTaskDefWithProvider)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
for p in self.impls_by_provider_id.values():
|
||||
|
@ -357,11 +355,38 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
|||
|
||||
|
||||
class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
|
||||
async def list_eval_tasks(self) -> List[ScoringFnDefWithProvider]:
|
||||
async def list_eval_tasks(self) -> List[EvalTask]:
|
||||
return await self.get_all_with_type("eval_task")
|
||||
|
||||
async def get_eval_task(self, name: str) -> Optional[EvalTaskDefWithProvider]:
|
||||
async def get_eval_task(self, name: str) -> Optional[EvalTask]:
|
||||
return await self.get_object_by_identifier(name)
|
||||
|
||||
async def register_eval_task(self, eval_task_def: EvalTaskDefWithProvider) -> None:
|
||||
await self.register_object(eval_task_def)
|
||||
async def register_eval_task(
|
||||
self,
|
||||
eval_task_id: str,
|
||||
dataset_id: str,
|
||||
scoring_functions: List[str],
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
provider_eval_task_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
) -> None:
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
if provider_id is None:
|
||||
if len(self.impls_by_provider_id) == 1:
|
||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
"No provider specified and multiple providers available. Please specify a provider_id."
|
||||
)
|
||||
if provider_eval_task_id is None:
|
||||
provider_eval_task_id = eval_task_id
|
||||
eval_task = EvalTask(
|
||||
identifier=eval_task_id,
|
||||
dataset_id=dataset_id,
|
||||
scoring_functions=scoring_functions,
|
||||
metadata=metadata,
|
||||
provider_id=provider_id,
|
||||
provider_resource_id=provider_eval_task_id,
|
||||
)
|
||||
await self.register_object(eval_task)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue