mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-20 20:38:41 +00:00
unregister api for dataset
Test Plan: **Type checker and check that the build compiles** **Unit Tests** **E2E Tests** // Screenshots and videos | Before | After | |--| | … | … |
This commit is contained in:
parent
2137b0af40
commit
58d664ab31
9 changed files with 134 additions and 0 deletions
|
|
@ -63,6 +63,8 @@ class MemoryBanksProtocolPrivate(Protocol):
|
|||
class DatasetsProtocolPrivate(Protocol):
|
||||
async def register_dataset(self, dataset: Dataset) -> None: ...
|
||||
|
||||
async def unregister_dataset(self, dataset_id: str) -> None: ...
|
||||
|
||||
|
||||
class ScoringFunctionsProtocolPrivate(Protocol):
|
||||
async def list_scoring_functions(self) -> List[ScoringFn]: ...
|
||||
|
|
|
|||
|
|
@ -97,6 +97,9 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
dataset_impl=dataset_impl,
|
||||
)
|
||||
|
||||
async def unregister_dataset(self, dataset_id: str) -> None:
|
||||
del self.dataset_infos[dataset_id]
|
||||
|
||||
async def get_rows_paginated(
|
||||
self,
|
||||
dataset_id: str,
|
||||
|
|
|
|||
|
|
@ -63,6 +63,11 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
)
|
||||
self.dataset_infos[dataset_def.identifier] = dataset_def
|
||||
|
||||
async def unregister_dataset(self, dataset_id: str) -> None:
|
||||
key = f"{DATASETS_PREFIX}{dataset_id}"
|
||||
await self.kvstore.delete(key=key)
|
||||
del self.dataset_infos[dataset_id]
|
||||
|
||||
async def get_rows_paginated(
|
||||
self,
|
||||
dataset_id: str,
|
||||
|
|
|
|||
|
|
@ -81,6 +81,18 @@ class TestDatasetIO:
|
|||
assert len(response) == 1
|
||||
assert response[0].identifier == "test_dataset"
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
# unregister a dataset that does not exist
|
||||
await datasets_impl.unregister_dataset("test_dataset2")
|
||||
|
||||
await datasets_impl.unregister_dataset("test_dataset")
|
||||
response = await datasets_impl.list_datasets()
|
||||
assert isinstance(response, list)
|
||||
assert len(response) == 0
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await datasets_impl.unregister_dataset("test_dataset")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_rows_paginated(self, datasetio_stack):
|
||||
datasetio_impl, datasets_impl = datasetio_stack
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue