mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
make distribution registry thread safe and other fixes (#449)
This PR makes the following changes: 1) Fixes the get_all and initialize impl to actually read the values returned from the range call to kvstore and not keys. 2) The start_key and end_key are fixed to correct perform the range query after the key format changes 3) Made the cache registry thread safe since there are multiple initializes called for each routing table. Tests: * Start stack * Register dataset * Kill stack * Bring stack up * dataset list ``` llama-stack-client datasets list +--------------+---------------+---------------------------------------------------------------------------------+---------+ | identifier | provider_id | metadata | type | +==============+===============+=================================================================================+=========+ | alpaca | huggingface-0 | {} | dataset | +--------------+---------------+---------------------------------------------------------------------------------+---------+ | mmlu | huggingface-0 | {'path': 'llama-stack/evals', 'name': 'evals__mmlu__details', 'split': 'train'} | dataset | +--------------+---------------+---------------------------------------------------------------------------------+---------+ ``` Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
parent
15dee2b8b8
commit
e90ea1ab1e
3 changed files with 148 additions and 48 deletions
|
@ -302,7 +302,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
|||
|
||||
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||
async def list_datasets(self) -> List[Dataset]:
|
||||
return await self.get_all_with_type("dataset")
|
||||
return await self.get_all_with_type(ResourceType.dataset.value)
|
||||
|
||||
async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
|
||||
return await self.get_object_by_identifier("dataset", dataset_id)
|
||||
|
@ -341,7 +341,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
|||
|
||||
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
||||
async def list_scoring_functions(self) -> List[ScoringFn]:
|
||||
return await self.get_all_with_type("scoring_function")
|
||||
return await self.get_all_with_type(ResourceType.scoring_function.value)
|
||||
|
||||
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
|
||||
return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
|
||||
|
@ -355,8 +355,6 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
|||
provider_id: Optional[str] = None,
|
||||
params: Optional[ScoringFnParams] = None,
|
||||
) -> None:
|
||||
if params is None:
|
||||
params = {}
|
||||
if provider_scoring_fn_id is None:
|
||||
provider_scoring_fn_id = scoring_fn_id
|
||||
if provider_id is None:
|
||||
|
@ -371,6 +369,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
|||
description=description,
|
||||
return_type=return_type,
|
||||
provider_resource_id=provider_scoring_fn_id,
|
||||
provider_id=provider_id,
|
||||
params=params,
|
||||
)
|
||||
scoring_fn.provider_id = provider_id
|
||||
|
@ -379,7 +378,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
|||
|
||||
class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
|
||||
async def list_eval_tasks(self) -> List[EvalTask]:
|
||||
return await self.get_all_with_type("eval_task")
|
||||
return await self.get_all_with_type(ResourceType.eval_task.value)
|
||||
|
||||
async def get_eval_task(self, name: str) -> Optional[EvalTask]:
|
||||
return await self.get_object_by_identifier("eval_task", name)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue