diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html
index 090253804..4f220ea1e 100644
--- a/docs/resources/llama-stack-spec.html
+++ b/docs/resources/llama-stack-spec.html
@@ -2291,6 +2291,39 @@
"required": true
}
}
+ },
+ "/alpha/datasets/unregister": {
+ "post": {
+ "responses": {
+ "200": {
+ "description": "OK"
+ }
+ },
+ "tags": [
+ "Datasets"
+ ],
+ "parameters": [
+ {
+ "name": "X-LlamaStack-ProviderData",
+ "in": "header",
+ "description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
+ "required": false,
+ "schema": {
+ "type": "string"
+ }
+ }
+ ],
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/UnregisterDatasetRequest"
+ }
+ }
+ },
+ "required": true
+ }
+ }
}
},
"jsonSchemaDialect": "https://json-schema.org/draft/2020-12/schema",
@@ -7917,6 +7950,18 @@
"required": [
"model_id"
]
+ },
+ "UnregisterDatasetRequest": {
+ "type": "object",
+ "properties": {
+ "dataset_id": {
+ "type": "string"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "dataset_id"
+ ]
}
},
"responses": {}
@@ -8529,6 +8574,10 @@
"name": "UnregisterModelRequest",
"description": ""
},
+ {
+ "name": "UnregisterDatasetRequest",
+ "description": ""
+ },
{
"name": "UnstructuredLogEvent",
"description": ""
@@ -8718,6 +8767,7 @@
"URL",
"UnregisterMemoryBankRequest",
"UnregisterModelRequest",
+ "UnregisterDatasetRequest",
"UnstructuredLogEvent",
"UserMessage",
"VectorMemoryBank",
diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml
index 8ffd9fdef..6564ddf3f 100644
--- a/docs/resources/llama-stack-spec.yaml
+++ b/docs/resources/llama-stack-spec.yaml
@@ -3253,6 +3253,14 @@ components:
required:
- model_id
type: object
+ UnregisterDatasetRequest:
+ additionalProperties: false
+ properties:
+ dataset_id:
+ type: string
+ required:
+ - dataset_id
+ type: object
UnstructuredLogEvent:
additionalProperties: false
properties:
@@ -3789,6 +3797,27 @@ paths:
description: OK
tags:
- Datasets
+ /alpha/datasets/unregister:
+ post:
+ parameters:
+ - description: JSON-encoded provider data which will be made available to the
+ adapter servicing the API
+ in: header
+ name: X-LlamaStack-ProviderData
+ required: false
+ schema:
+ type: string
+ requestBody:
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/UnregisterDatasetRequest'
+ required: true
+ responses:
+ '200':
+ description: OK
+ tags:
+ - Datasets
/alpha/eval-tasks/get:
get:
parameters:
@@ -5242,6 +5271,9 @@ tags:
- description:
name: UnregisterModelRequest
+- description:
+ name: UnregisterDatasetRequest
- description:
name: UnstructuredLogEvent
@@ -5418,6 +5450,7 @@ x-tagGroups:
- URL
- UnregisterMemoryBankRequest
- UnregisterModelRequest
+ - UnregisterDatasetRequest
- UnstructuredLogEvent
- UserMessage
- VectorMemoryBank
diff --git a/llama_stack/apis/datasets/client.py b/llama_stack/apis/datasets/client.py
index 9e5891e74..c379a49fb 100644
--- a/llama_stack/apis/datasets/client.py
+++ b/llama_stack/apis/datasets/client.py
@@ -78,6 +78,21 @@ class DatasetsClient(Datasets):
return [DatasetDefWithProvider(**x) for x in response.json()]
+    async def unregister_dataset(self, dataset_id: str) -> None:
+        """Unregister ``dataset_id`` on the server; raise on an HTTP error status."""
+        async with httpx.AsyncClient() as client:
+            # The endpoint is declared as POST with a JSON body carrying
+            # `dataset_id`, so the request must use that verb and encoding
+            # rather than DELETE with query parameters.
+            response = await client.post(
+                f"{self.base_url}/datasets/unregister",
+                json={"dataset_id": dataset_id},
+                headers={"Content-Type": "application/json"},
+                timeout=60,
+            )
+            # Surface 4xx/5xx replies as httpx.HTTPStatusError.
+            response.raise_for_status()
+
async def run_main(host: str, port: int):
client = DatasetsClient(f"http://{host}:{port}")
diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py
index 2ab958782..e1ac4af21 100644
--- a/llama_stack/apis/datasets/datasets.py
+++ b/llama_stack/apis/datasets/datasets.py
@@ -64,3 +64,9 @@ class Datasets(Protocol):
@webmethod(route="/datasets/list", method="GET")
async def list_datasets(self) -> List[Dataset]: ...
+
+ @webmethod(route="/datasets/unregister", method="POST")  # request body: UnregisterDatasetRequest
+ async def unregister_dataset(
+ self,
+ dataset_id: str,  # required; identifier of the dataset to unregister
+ ) -> None: ...  # declaration only; implementations supply the behaviour
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index 4df693b26..2fb5a5e1c 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -57,6 +57,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
return await p.unregister_memory_bank(obj.identifier)
elif api == Api.inference:
return await p.unregister_model(obj.identifier)
+ elif api == Api.datasetio:
+ return await p.unregister_dataset(obj.identifier)
else:
raise ValueError(f"Unregister not supported for {api}")
@@ -354,6 +356,12 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
)
await self.register_object(dataset)
+    async def unregister_dataset(self, dataset_id: str) -> None:
+        """Unregister ``dataset_id``; raise ValueError when no such dataset exists."""
+        if (found := await self.get_dataset(dataset_id)) is None:
+            raise ValueError(f"Dataset {dataset_id} not found")
+        await self.unregister_object(found)
+
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> List[ScoringFn]:
diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index 080204e45..8e89bcc72 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -63,6 +63,8 @@ class MemoryBanksProtocolPrivate(Protocol):
class DatasetsProtocolPrivate(Protocol):
async def register_dataset(self, dataset: Dataset) -> None: ...
+ async def unregister_dataset(self, dataset_id: str) -> None: ...  # provider hook: receives the id, not the Dataset
+
class ScoringFunctionsProtocolPrivate(Protocol):
async def list_scoring_functions(self) -> List[ScoringFn]: ...
diff --git a/llama_stack/providers/inline/datasetio/localfs/datasetio.py b/llama_stack/providers/inline/datasetio/localfs/datasetio.py
index 4de1850ae..010610056 100644
--- a/llama_stack/providers/inline/datasetio/localfs/datasetio.py
+++ b/llama_stack/providers/inline/datasetio/localfs/datasetio.py
@@ -97,6 +97,9 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
dataset_impl=dataset_impl,
)
+    async def unregister_dataset(self, dataset_id: str) -> None:
+        self.dataset_infos.pop(dataset_id)  # no default: unknown ids still raise KeyError
+
async def get_rows_paginated(
self,
dataset_id: str,
diff --git a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py
index c2e4506bf..cdd5d9cd3 100644
--- a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py
+++ b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py
@@ -64,6 +64,11 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
)
self.dataset_infos[dataset_def.identifier] = dataset_def
+    async def unregister_dataset(self, dataset_id: str) -> None:
+        """Delete the dataset's kvstore entry, then drop its in-memory record."""
+        await self.kvstore.delete(key=f"{DATASETS_PREFIX}{dataset_id}")
+        self.dataset_infos.pop(dataset_id)
+
async def get_rows_paginated(
self,
dataset_id: str,
diff --git a/llama_stack/providers/tests/datasetio/test_datasetio.py b/llama_stack/providers/tests/datasetio/test_datasetio.py
index dd2cbd019..7d88b6115 100644
--- a/llama_stack/providers/tests/datasetio/test_datasetio.py
+++ b/llama_stack/providers/tests/datasetio/test_datasetio.py
@@ -81,6 +81,18 @@ class TestDatasetIO:
assert len(response) == 1
assert response[0].identifier == "test_dataset"
+ with pytest.raises(Exception) as exc_info:
+ # unregister a dataset that does not exist
+ await datasets_impl.unregister_dataset("test_dataset2")
+
+ await datasets_impl.unregister_dataset("test_dataset")
+ response = await datasets_impl.list_datasets()
+ assert isinstance(response, list)
+ assert len(response) == 0
+
+ with pytest.raises(Exception) as exc_info:
+ await datasets_impl.unregister_dataset("test_dataset")
+
@pytest.mark.asyncio
async def test_get_rows_paginated(self, datasetio_stack):
datasetio_impl, datasets_impl = datasetio_stack