migrate scoring fns to resource (#422)

* fix after rebase

* remove print

---------

Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
Dinesh Yeduguru 2024-11-11 17:28:48 -08:00 committed by GitHub
parent 3802edfc50
commit 0a3b3d5fb6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 113 additions and 62 deletions

View file

@ -81,7 +81,10 @@ class CommonRoutingTableImpl(RoutingTable):
# so we should just override the provider in-place
obj.provider_id = provider_id
else:
obj = cls(**obj.model_dump(), provider_id=provider_id)
# Create a copy of the model data and explicitly set provider_id
model_data = obj.model_dump()
model_data["provider_id"] = provider_id
obj = cls(**model_data)
await self.dist_registry.register(obj)
# Register all objects from providers
@ -101,7 +104,7 @@ class CommonRoutingTableImpl(RoutingTable):
elif api == Api.scoring:
p.scoring_function_store = self
scoring_functions = await p.list_scoring_functions()
await add_objects(scoring_functions, pid, ScoringFnDefWithProvider)
await add_objects(scoring_functions, pid, ScoringFn)
elif api == Api.eval:
p.eval_task_store = self
@ -340,18 +343,41 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]:
return await self.get_all_with_type("scoring_fn")
async def list_scoring_functions(self) -> List[ScoringFn]:
return await self.get_all_with_type(ResourceType.scoring_function.value)
async def get_scoring_function(
self, name: str
) -> Optional[ScoringFnDefWithProvider]:
return await self.get_object_by_identifier(name)
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
return await self.get_object_by_identifier(scoring_fn_id)
async def register_scoring_function(
self, function_def: ScoringFnDefWithProvider
self,
scoring_fn_id: str,
description: str,
return_type: ParamType,
provider_scoring_fn_id: Optional[str] = None,
provider_id: Optional[str] = None,
params: Optional[ScoringFnParams] = None,
) -> None:
await self.register_object(function_def)
if params is None:
params = {}
if provider_scoring_fn_id is None:
provider_scoring_fn_id = scoring_fn_id
if provider_id is None:
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
scoring_fn = ScoringFn(
identifier=scoring_fn_id,
description=description,
return_type=return_type,
provider_resource_id=provider_scoring_fn_id,
params=params,
)
scoring_fn.provider_id = provider_id
await self.register_object(scoring_fn)
class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):