mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
migrate scoring fns to resource (#422)
* fix after rebase * remove print --------- Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
parent
3802edfc50
commit
0a3b3d5fb6
16 changed files with 113 additions and 62 deletions
|
@ -81,7 +81,10 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
# so we should just override the provider in-place
|
||||
obj.provider_id = provider_id
|
||||
else:
|
||||
obj = cls(**obj.model_dump(), provider_id=provider_id)
|
||||
# Create a copy of the model data and explicitly set provider_id
|
||||
model_data = obj.model_dump()
|
||||
model_data["provider_id"] = provider_id
|
||||
obj = cls(**model_data)
|
||||
await self.dist_registry.register(obj)
|
||||
|
||||
# Register all objects from providers
|
||||
|
@ -101,7 +104,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
elif api == Api.scoring:
|
||||
p.scoring_function_store = self
|
||||
scoring_functions = await p.list_scoring_functions()
|
||||
await add_objects(scoring_functions, pid, ScoringFnDefWithProvider)
|
||||
await add_objects(scoring_functions, pid, ScoringFn)
|
||||
|
||||
elif api == Api.eval:
|
||||
p.eval_task_store = self
|
||||
|
@ -340,18 +343,41 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
|||
|
||||
|
||||
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
||||
async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]:
|
||||
return await self.get_all_with_type("scoring_fn")
|
||||
async def list_scoring_functions(self) -> List[ScoringFn]:
|
||||
return await self.get_all_with_type(ResourceType.scoring_function.value)
|
||||
|
||||
async def get_scoring_function(
|
||||
self, name: str
|
||||
) -> Optional[ScoringFnDefWithProvider]:
|
||||
return await self.get_object_by_identifier(name)
|
||||
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
|
||||
return await self.get_object_by_identifier(scoring_fn_id)
|
||||
|
||||
async def register_scoring_function(
|
||||
self, function_def: ScoringFnDefWithProvider
|
||||
self,
|
||||
scoring_fn_id: str,
|
||||
description: str,
|
||||
return_type: ParamType,
|
||||
provider_scoring_fn_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
params: Optional[ScoringFnParams] = None,
|
||||
) -> None:
|
||||
await self.register_object(function_def)
|
||||
if params is None:
|
||||
params = {}
|
||||
if provider_scoring_fn_id is None:
|
||||
provider_scoring_fn_id = scoring_fn_id
|
||||
if provider_id is None:
|
||||
if len(self.impls_by_provider_id) == 1:
|
||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
"No provider specified and multiple providers available. Please specify a provider_id."
|
||||
)
|
||||
scoring_fn = ScoringFn(
|
||||
identifier=scoring_fn_id,
|
||||
description=description,
|
||||
return_type=return_type,
|
||||
provider_resource_id=provider_scoring_fn_id,
|
||||
params=params,
|
||||
)
|
||||
scoring_fn.provider_id = provider_id
|
||||
await self.register_object(scoring_fn)
|
||||
|
||||
|
||||
class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue