Add unregister API for datasets

Test Plan:

**Type checker and check that the build compiles**

**Unit Tests**

**E2E Tests**

// Screenshots and videos
| Before | After |
|--|--|
| … | … |
This commit is contained in:
Sixian Yi 2024-11-22 19:44:23 -08:00
parent 2137b0af40
commit 58d664ab31
9 changed files with 134 additions and 0 deletions

View file

@ -2291,6 +2291,39 @@
"required": true "required": true
} }
} }
},
"/alpha/datasets/unregister": {
"post": {
"responses": {
"200": {
"description": "OK"
}
},
"tags": [
"Datasets"
],
"parameters": [
{
"name": "X-LlamaStack-ProviderData",
"in": "header",
"description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
"required": false,
"schema": {
"type": "string"
}
}
],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/UnregisterDatasetRequest"
}
}
},
"required": true
}
}
} }
}, },
"jsonSchemaDialect": "https://json-schema.org/draft/2020-12/schema", "jsonSchemaDialect": "https://json-schema.org/draft/2020-12/schema",
@ -7916,6 +7949,18 @@
"required": [ "required": [
"model_id" "model_id"
] ]
},
"UnregisterDatasetRequest": {
"type": "object",
"properties": {
"dataset_id": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"dataset_id"
]
} }
}, },
"responses": {} "responses": {}
@ -8528,6 +8573,10 @@
"name": "UnregisterModelRequest", "name": "UnregisterModelRequest",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UnregisterModelRequest\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/UnregisterModelRequest\" />"
}, },
{
"name": "UnregisterDatasetRequest",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UnregisterDatasetRequest\" />"
},
{ {
"name": "UnstructuredLogEvent", "name": "UnstructuredLogEvent",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UnstructuredLogEvent\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/UnstructuredLogEvent\" />"
@ -8717,6 +8766,7 @@
"URL", "URL",
"UnregisterMemoryBankRequest", "UnregisterMemoryBankRequest",
"UnregisterModelRequest", "UnregisterModelRequest",
"UnregisterDatasetRequest",
"UnstructuredLogEvent", "UnstructuredLogEvent",
"UserMessage", "UserMessage",
"VectorMemoryBank", "VectorMemoryBank",

View file

@ -3252,6 +3252,14 @@ components:
required: required:
- model_id - model_id
type: object type: object
UnregisterDatasetRequest:
additionalProperties: false
properties:
dataset_id:
type: string
required:
- dataset_id
type: object
UnstructuredLogEvent: UnstructuredLogEvent:
additionalProperties: false additionalProperties: false
properties: properties:
@ -3789,6 +3797,27 @@ paths:
description: OK description: OK
tags: tags:
- Datasets - Datasets
/alpha/datasets/unregister:
post:
parameters:
- description: JSON-encoded provider data which will be made available to the
adapter servicing the API
in: header
name: X-LlamaStack-ProviderData
required: false
schema:
type: string
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/UnregisterDatasetRequest'
required: true
responses:
'200':
description: OK
tags:
- Datasets
/alpha/eval-tasks/get: /alpha/eval-tasks/get:
get: get:
parameters: parameters:
@ -5242,6 +5271,9 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/UnregisterModelRequest" - description: <SchemaDefinition schemaRef="#/components/schemas/UnregisterModelRequest"
/> />
name: UnregisterModelRequest name: UnregisterModelRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/UnregisterDatasetRequest"
/>
name: UnregisterDatasetRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/UnstructuredLogEvent" - description: <SchemaDefinition schemaRef="#/components/schemas/UnstructuredLogEvent"
/> />
name: UnstructuredLogEvent name: UnstructuredLogEvent
@ -5418,6 +5450,7 @@ x-tagGroups:
- URL - URL
- UnregisterMemoryBankRequest - UnregisterMemoryBankRequest
- UnregisterModelRequest - UnregisterModelRequest
- UnregisterDatasetRequest
- UnstructuredLogEvent - UnstructuredLogEvent
- UserMessage - UserMessage
- VectorMemoryBank - VectorMemoryBank

View file

@ -78,6 +78,21 @@ class DatasetsClient(Datasets):
return [DatasetDefWithProvider(**x) for x in response.json()] return [DatasetDefWithProvider(**x) for x in response.json()]
async def unregister_dataset(
    self,
    dataset_id: str,
) -> None:
    """Ask the server to unregister the dataset named ``dataset_id``.

    The endpoint is declared as ``POST /datasets/unregister`` with a JSON
    request body carrying ``dataset_id``, so the identifier is sent as-is
    in the body rather than as a query parameter on a DELETE request.

    Args:
        dataset_id: Identifier of the dataset to unregister.

    Raises:
        httpx.HTTPStatusError: If the server answers with an error status.
    """
    async with httpx.AsyncClient() as client:
        response = await client.post(
            f"{self.base_url}/datasets/unregister",
            # Send the raw identifier: it is a plain string, not a JSON
            # document, so it must not be passed through json.loads().
            json={"dataset_id": dataset_id},
            headers={"Content-Type": "application/json"},
            timeout=60,
        )
        response.raise_for_status()
async def run_main(host: str, port: int): async def run_main(host: str, port: int):
client = DatasetsClient(f"http://{host}:{port}") client = DatasetsClient(f"http://{host}:{port}")

View file

@ -64,3 +64,9 @@ class Datasets(Protocol):
@webmethod(route="/datasets/list", method="GET") @webmethod(route="/datasets/list", method="GET")
async def list_datasets(self) -> List[Dataset]: ... async def list_datasets(self) -> List[Dataset]: ...
# Protocol stub: unregister the dataset identified by ``dataset_id``.
# Served as POST /datasets/unregister; implementations provide the behavior.
@webmethod(route="/datasets/unregister", method="POST")
async def unregister_dataset(self, dataset_id: str) -> None: ...

View file

@ -57,6 +57,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
return await p.unregister_memory_bank(obj.identifier) return await p.unregister_memory_bank(obj.identifier)
elif api == Api.inference: elif api == Api.inference:
return await p.unregister_model(obj.identifier) return await p.unregister_model(obj.identifier)
elif api == Api.datasetio:
return await p.unregister_dataset(obj.identifier)
else: else:
raise ValueError(f"Unregister not supported for {api}") raise ValueError(f"Unregister not supported for {api}")
@ -354,6 +356,12 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
) )
await self.register_object(dataset) await self.register_object(dataset)
async def unregister_dataset(self, dataset_id: str) -> None:
    """Unregister the dataset known under ``dataset_id``.

    Looks the dataset up via ``get_dataset`` and hands the result to
    ``unregister_object``.

    Raises:
        ValueError: If the lookup returns ``None``.
    """
    existing = await self.get_dataset(dataset_id)
    if existing is not None:
        await self.unregister_object(existing)
        return
    raise ValueError(f"Dataset {dataset_id} not found")
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> List[ScoringFn]: async def list_scoring_functions(self) -> List[ScoringFn]:

View file

@ -63,6 +63,8 @@ class MemoryBanksProtocolPrivate(Protocol):
class DatasetsProtocolPrivate(Protocol): class DatasetsProtocolPrivate(Protocol):
async def register_dataset(self, dataset: Dataset) -> None: ... async def register_dataset(self, dataset: Dataset) -> None: ...
async def unregister_dataset(self, dataset_id: str) -> None:
    """Unregister the dataset identified by ``dataset_id`` (protocol stub)."""
    ...
class ScoringFunctionsProtocolPrivate(Protocol): class ScoringFunctionsProtocolPrivate(Protocol):
async def list_scoring_functions(self) -> List[ScoringFn]: ... async def list_scoring_functions(self) -> List[ScoringFn]: ...

View file

@ -97,6 +97,9 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
dataset_impl=dataset_impl, dataset_impl=dataset_impl,
) )
async def unregister_dataset(self, dataset_id: str) -> None:
    """Drop ``dataset_id`` from ``self.dataset_infos``.

    Raises:
        KeyError: If ``dataset_id`` is not present in ``self.dataset_infos``.
    """
    # pop() without a default raises KeyError for an unknown id.
    self.dataset_infos.pop(dataset_id)
async def get_rows_paginated( async def get_rows_paginated(
self, self,
dataset_id: str, dataset_id: str,

View file

@ -63,6 +63,11 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
) )
self.dataset_infos[dataset_def.identifier] = dataset_def self.dataset_infos[dataset_def.identifier] = dataset_def
async def unregister_dataset(self, dataset_id: str) -> None:
    """Delete the stored entry for ``dataset_id`` and forget it in memory.

    The kvstore entry under ``DATASETS_PREFIX + dataset_id`` is deleted
    first, then the id is removed from ``self.dataset_infos``.

    Raises:
        KeyError: If ``dataset_id`` is not present in ``self.dataset_infos``
            (raised after the kvstore delete has already been issued).
    """
    await self.kvstore.delete(key=f"{DATASETS_PREFIX}{dataset_id}")
    self.dataset_infos.pop(dataset_id)
async def get_rows_paginated( async def get_rows_paginated(
self, self,
dataset_id: str, dataset_id: str,

View file

@ -81,6 +81,18 @@ class TestDatasetIO:
assert len(response) == 1 assert len(response) == 1
assert response[0].identifier == "test_dataset" assert response[0].identifier == "test_dataset"
with pytest.raises(Exception) as exc_info:
# unregister a dataset that does not exist
await datasets_impl.unregister_dataset("test_dataset2")
await datasets_impl.unregister_dataset("test_dataset")
response = await datasets_impl.list_datasets()
assert isinstance(response, list)
assert len(response) == 0
with pytest.raises(Exception) as exc_info:
await datasets_impl.unregister_dataset("test_dataset")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_rows_paginated(self, datasetio_stack): async def test_get_rows_paginated(self, datasetio_stack):
datasetio_impl, datasets_impl = datasetio_stack datasetio_impl, datasets_impl = datasetio_stack