make distribution registry thread safe and other fixes (#449)

This PR makes the following changes:
1) Fixes the get_all and initialize impl to actually read the values
returned from the range call to kvstore and not keys.
2) The start_key and end_key are fixed to correct perform the range
query after the key format changes
3) Made the cache registry thread safe since there are multiple
initializes called for each routing table.

Tests:
* Start stack
* Register dataset
* Kill stack
* Bring stack up
* dataset list
```
 llama-stack-client datasets list
+--------------+---------------+---------------------------------------------------------------------------------+---------+
| identifier   | provider_id   | metadata                                                                        | type    |
+==============+===============+=================================================================================+=========+
| alpaca       | huggingface-0 | {}                                                                              | dataset |
+--------------+---------------+---------------------------------------------------------------------------------+---------+
| mmlu         | huggingface-0 | {'path': 'llama-stack/evals', 'name': 'evals__mmlu__details', 'split': 'train'} | dataset |
+--------------+---------------+---------------------------------------------------------------------------------+---------+
```

Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
Dinesh Yeduguru 2024-11-13 15:12:34 -08:00 committed by GitHub
parent 15dee2b8b8
commit e90ea1ab1e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 148 additions and 48 deletions

View file

@ -302,7 +302,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> List[Dataset]:
return await self.get_all_with_type("dataset")
return await self.get_all_with_type(ResourceType.dataset.value)
async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
return await self.get_object_by_identifier("dataset", dataset_id)
@ -341,7 +341,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> List[ScoringFn]:
return await self.get_all_with_type("scoring_function")
return await self.get_all_with_type(ResourceType.scoring_function.value)
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
@ -355,8 +355,6 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
provider_id: Optional[str] = None,
params: Optional[ScoringFnParams] = None,
) -> None:
if params is None:
params = {}
if provider_scoring_fn_id is None:
provider_scoring_fn_id = scoring_fn_id
if provider_id is None:
@ -371,6 +369,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
description=description,
return_type=return_type,
provider_resource_id=provider_scoring_fn_id,
provider_id=provider_id,
params=params,
)
scoring_fn.provider_id = provider_id
@ -379,7 +378,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
async def list_eval_tasks(self) -> List[EvalTask]:
return await self.get_all_with_type("eval_task")
return await self.get_all_with_type(ResourceType.eval_task.value)
async def get_eval_task(self, name: str) -> Optional[EvalTask]:
return await self.get_object_by_identifier("eval_task", name)