unregister api for dataset

Test Plan:

**Type checker and check that the build compiles**

**Unit Tests**

**E2E Tests**

// Screenshots and videos
| Before | After |
|--|
| … | … |
This commit is contained in:
Sixian Yi 2024-11-22 19:44:23 -08:00
parent 2137b0af40
commit 58d664ab31
9 changed files with 134 additions and 0 deletions

View file

@ -2291,6 +2291,39 @@
"required": true
}
}
},
"/alpha/datasets/unregister": {
"post": {
"responses": {
"200": {
"description": "OK"
}
},
"tags": [
"Datasets"
],
"parameters": [
{
"name": "X-LlamaStack-ProviderData",
"in": "header",
"description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
"required": false,
"schema": {
"type": "string"
}
}
],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/UnregisterDatasetRequest"
}
}
},
"required": true
}
}
}
},
"jsonSchemaDialect": "https://json-schema.org/draft/2020-12/schema",
@ -7916,6 +7949,18 @@
"required": [
"model_id"
]
},
"UnregisterDatasetRequest": {
"type": "object",
"properties": {
"dataset_id": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"dataset_id"
]
}
},
"responses": {}
@ -8528,6 +8573,10 @@
"name": "UnregisterModelRequest",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UnregisterModelRequest\" />"
},
{
"name": "UnregisterDatasetRequest",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UnregisterDatasetRequest\" />"
},
{
"name": "UnstructuredLogEvent",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UnstructuredLogEvent\" />"
@ -8717,6 +8766,7 @@
"URL",
"UnregisterMemoryBankRequest",
"UnregisterModelRequest",
"UnregisterDatasetRequest",
"UnstructuredLogEvent",
"UserMessage",
"VectorMemoryBank",

View file

@ -3252,6 +3252,14 @@ components:
required:
- model_id
type: object
UnregisterDatasetRequest:
additionalProperties: false
properties:
dataset_id:
type: string
required:
- dataset_id
type: object
UnstructuredLogEvent:
additionalProperties: false
properties:
@ -3789,6 +3797,27 @@ paths:
description: OK
tags:
- Datasets
/alpha/datasets/unregister:
post:
parameters:
- description: JSON-encoded provider data which will be made available to the
adapter servicing the API
in: header
name: X-LlamaStack-ProviderData
required: false
schema:
type: string
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/UnregisterDatasetRequest'
required: true
responses:
'200':
description: OK
tags:
- Datasets
/alpha/eval-tasks/get:
get:
parameters:
@ -5242,6 +5271,9 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/UnregisterModelRequest"
/>
name: UnregisterModelRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/UnregisterDatasetRequest"
/>
name: UnregisterDatasetRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/UnstructuredLogEvent"
/>
name: UnstructuredLogEvent
@ -5418,6 +5450,7 @@ x-tagGroups:
- URL
- UnregisterMemoryBankRequest
- UnregisterModelRequest
- UnregisterDatasetRequest
- UnstructuredLogEvent
- UserMessage
- VectorMemoryBank

View file

@ -78,6 +78,21 @@ class DatasetsClient(Datasets):
return [DatasetDefWithProvider(**x) for x in response.json()]
async def unregister_dataset(
self,
dataset_id: str,
) -> None:
async with httpx.AsyncClient() as client:
response = await client.delete(
f"{self.base_url}/datasets/unregister",
params={
"dataset_id": json.loads(dataset_id),
},
headers={"Content-Type": "application/json"},
timeout=60,
)
response.raise_for_status()
async def run_main(host: str, port: int):
client = DatasetsClient(f"http://{host}:{port}")

View file

@ -64,3 +64,9 @@ class Datasets(Protocol):
@webmethod(route="/datasets/list", method="GET")
async def list_datasets(self) -> List[Dataset]: ...
@webmethod(route="/datasets/unregister", method="POST")
async def unregister_dataset(
self,
dataset_id: str,
) -> None: ...

View file

@ -57,6 +57,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
return await p.unregister_memory_bank(obj.identifier)
elif api == Api.inference:
return await p.unregister_model(obj.identifier)
elif api == Api.datasetio:
return await p.unregister_dataset(obj.identifier)
else:
raise ValueError(f"Unregister not supported for {api}")
@ -354,6 +356,12 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
)
await self.register_object(dataset)
async def unregister_dataset(self, dataset_id: str) -> None:
dataset = await self.get_dataset(dataset_id)
if dataset is None:
raise ValueError(f"Dataset {dataset_id} not found")
await self.unregister_object(dataset)
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> List[ScoringFn]:

View file

@ -63,6 +63,8 @@ class MemoryBanksProtocolPrivate(Protocol):
class DatasetsProtocolPrivate(Protocol):
async def register_dataset(self, dataset: Dataset) -> None: ...
async def unregister_dataset(self, dataset_id: str) -> None: ...
class ScoringFunctionsProtocolPrivate(Protocol):
async def list_scoring_functions(self) -> List[ScoringFn]: ...

View file

@ -97,6 +97,9 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
dataset_impl=dataset_impl,
)
async def unregister_dataset(self, dataset_id: str) -> None:
del self.dataset_infos[dataset_id]
async def get_rows_paginated(
self,
dataset_id: str,

View file

@ -63,6 +63,11 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
)
self.dataset_infos[dataset_def.identifier] = dataset_def
async def unregister_dataset(self, dataset_id: str) -> None:
key = f"{DATASETS_PREFIX}{dataset_id}"
await self.kvstore.delete(key=key)
del self.dataset_infos[dataset_id]
async def get_rows_paginated(
self,
dataset_id: str,

View file

@ -81,6 +81,18 @@ class TestDatasetIO:
assert len(response) == 1
assert response[0].identifier == "test_dataset"
with pytest.raises(Exception) as exc_info:
# unregister a dataset that does not exist
await datasets_impl.unregister_dataset("test_dataset2")
await datasets_impl.unregister_dataset("test_dataset")
response = await datasets_impl.list_datasets()
assert isinstance(response, list)
assert len(response) == 0
with pytest.raises(Exception) as exc_info:
await datasets_impl.unregister_dataset("test_dataset")
@pytest.mark.asyncio
async def test_get_rows_paginated(self, datasetio_stack):
datasetio_impl, datasets_impl = datasetio_stack