migrate evals to resource

This commit is contained in:
Dinesh Yeduguru 2024-11-11 12:38:22 -08:00
parent b95cb5308f
commit 63b5eb929f
5 changed files with 57 additions and 40 deletions

View file

@ -105,8 +105,6 @@ class CommonRoutingTableImpl(RoutingTable):
elif api == Api.eval:
p.eval_task_store = self
eval_tasks = await p.list_eval_tasks()
await add_objects(eval_tasks, pid, EvalTaskDefWithProvider)
async def shutdown(self) -> None:
for p in self.impls_by_provider_id.values():
@ -357,11 +355,38 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
async def list_eval_tasks(self) -> List[ScoringFnDefWithProvider]:
async def list_eval_tasks(self) -> List[EvalTask]:
return await self.get_all_with_type("eval_task")
async def get_eval_task(self, name: str) -> Optional[EvalTaskDefWithProvider]:
async def get_eval_task(self, name: str) -> Optional[EvalTask]:
return await self.get_object_by_identifier(name)
async def register_eval_task(self, eval_task_def: EvalTaskDefWithProvider) -> None:
await self.register_object(eval_task_def)
async def register_eval_task(
self,
eval_task_id: str,
dataset_id: str,
scoring_functions: List[str],
metadata: Optional[Dict[str, Any]] = None,
provider_id: Optional[str] = None,
provider_eval_task_id: Optional[str] = None,
) -> None:
if metadata is None:
metadata = {}
if provider_id is None:
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
if provider_eval_task_id is None:
provider_eval_task_id = eval_task_id
eval_task = EvalTask(
identifier=eval_task_id,
dataset_id=dataset_id,
scoring_functions=scoring_functions,
metadata=metadata,
provider_id=provider_id,
provider_resource_id=provider_eval_task_id,
)
await self.register_object(eval_task)