unregister api for dataset

Test Plan:

**Type checker and check that the build compiles**

**Unit Tests**

**E2E Tests**

// Screenshots and videos
| Before | After |
|--|
| … | … |
This commit is contained in:
Sixian Yi 2024-11-22 19:44:23 -08:00
parent 2137b0af40
commit 58d664ab31
9 changed files with 134 additions and 0 deletions

View file

@ -78,6 +78,21 @@ class DatasetsClient(Datasets):
return [DatasetDefWithProvider(**x) for x in response.json()]
async def unregister_dataset(
self,
dataset_id: str,
) -> None:
async with httpx.AsyncClient() as client:
response = await client.delete(
f"{self.base_url}/datasets/unregister",
params={
"dataset_id": json.loads(dataset_id),
},
headers={"Content-Type": "application/json"},
timeout=60,
)
response.raise_for_status()
async def run_main(host: str, port: int):
client = DatasetsClient(f"http://{host}:{port}")

View file

@ -64,3 +64,9 @@ class Datasets(Protocol):
@webmethod(route="/datasets/list", method="GET")
async def list_datasets(self) -> List[Dataset]: ...
@webmethod(route="/datasets/unregister", method="POST")
async def unregister_dataset(
self,
dataset_id: str,
) -> None: ...

View file

@ -57,6 +57,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
return await p.unregister_memory_bank(obj.identifier)
elif api == Api.inference:
return await p.unregister_model(obj.identifier)
elif api == Api.datasetio:
return await p.unregister_dataset(obj.identifier)
else:
raise ValueError(f"Unregister not supported for {api}")
@ -354,6 +356,12 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
)
await self.register_object(dataset)
async def unregister_dataset(self, dataset_id: str) -> None:
dataset = await self.get_dataset(dataset_id)
if dataset is None:
raise ValueError(f"Dataset {dataset_id} not found")
await self.unregister_object(dataset)
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> List[ScoringFn]:

View file

@ -63,6 +63,8 @@ class MemoryBanksProtocolPrivate(Protocol):
class DatasetsProtocolPrivate(Protocol):
async def register_dataset(self, dataset: Dataset) -> None: ...
async def unregister_dataset(self, dataset_id: str) -> None: ...
class ScoringFunctionsProtocolPrivate(Protocol):
async def list_scoring_functions(self) -> List[ScoringFn]: ...

View file

@ -97,6 +97,9 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
dataset_impl=dataset_impl,
)
async def unregister_dataset(self, dataset_id: str) -> None:
del self.dataset_infos[dataset_id]
async def get_rows_paginated(
self,
dataset_id: str,

View file

@ -63,6 +63,11 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
)
self.dataset_infos[dataset_def.identifier] = dataset_def
async def unregister_dataset(self, dataset_id: str) -> None:
key = f"{DATASETS_PREFIX}{dataset_id}"
await self.kvstore.delete(key=key)
del self.dataset_infos[dataset_id]
async def get_rows_paginated(
self,
dataset_id: str,

View file

@ -81,6 +81,18 @@ class TestDatasetIO:
assert len(response) == 1
assert response[0].identifier == "test_dataset"
with pytest.raises(Exception) as exc_info:
# unregister a dataset that does not exist
await datasets_impl.unregister_dataset("test_dataset2")
await datasets_impl.unregister_dataset("test_dataset")
response = await datasets_impl.list_datasets()
assert isinstance(response, list)
assert len(response) == 0
with pytest.raises(Exception) as exc_info:
await datasets_impl.unregister_dataset("test_dataset")
@pytest.mark.asyncio
async def test_get_rows_paginated(self, datasetio_stack):
datasetio_impl, datasets_impl = datasetio_stack