forked from phoenix-oss/llama-stack-mirror
unregister API for dataset (#507)
# What does this PR do? 1) Implement `unregister_dataset(dataset_id)` API in both llama stack routing table and providers: It removes {dataset_id -> Dataset} mapping from routing table and removes the dataset_id references in provider as well (ex. for huggingface, we use a KV store to store the dataset id => dataset. we delete it during unregistering as well) 2) expose the datasets/unregister_dataset api endpoint ## Test Plan **Unit test:** ` pytest llama_stack/providers/tests/datasetio/test_datasetio.py -m "huggingface" -v -s --tb=short --disable-warnings ` **Test on endpoint:** tested llama stack using an ollama distribution template: 1) start an ollama server 2) Start a llama stack server with the default ollama distribution config + dataset/datasetsio APIs + datasetio provider ``` ---- .../ollama-run.yaml ... apis: - agents - inference - memory - safety - telemetry - datasetio - datasets providers: datasetio: - provider_id: localfs provider_type: inline::localfs config: {} ... ``` saw that the new API showed up in startup script ``` Serving API datasets GET /alpha/datasets/get GET /alpha/datasets/list POST /alpha/datasets/register POST /alpha/datasets/unregister ``` 3) query `/alpha/datasets/unregister` through curl (since we have not implemented unregister api in llama stack client) ``` (base) sxyi@sxyi-mbp llama-stack % llama-stack-client datasets register --dataset-id sixian --url https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/chat.rst --schema {} (base) sxyi@sxyi-mbp llama-stack % llama-stack-client datasets list ┏━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━┓ ┃ identifier ┃ provider_id ┃ metadata ┃ type ┃ ┡━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━┩ │ sixian │ localfs │ {} │ dataset │ └────────────┴─────────────┴──────────┴─────────┘ (base) sxyi@sxyi-mbp llama-stack % llama-stack-client datasets register --dataset-id sixian2 --url https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/chat.rst --schema {} (base) sxyi@sxyi-mbp llama-stack % llama-stack-client datasets list ┏━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━┓ ┃ identifier ┃ provider_id ┃ metadata ┃ type ┃ ┡━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━┩ │ sixian │ localfs │ {} │ dataset │ │ sixian2 │ localfs │ {} │ dataset │ └────────────┴─────────────┴──────────┴─────────┘ (base) sxyi@sxyi-mbp llama-stack % curl http://localhost:5001/alpha/datasets/unregister \ -H "Content-Type: application/json" \ -d '{"dataset_id": "sixian"}' null% (base) sxyi@sxyi-mbp llama-stack % llama-stack-client datasets list ┏━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━┓ ┃ identifier ┃ provider_id ┃ metadata ┃ type ┃ ┡━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━┩ │ sixian2 │ localfs │ {} │ dataset │ └────────────┴─────────────┴──────────┴─────────┘ (base) sxyi@sxyi-mbp llama-stack % curl http://localhost:5001/alpha/datasets/unregister \ -H "Content-Type: application/json" \ -d '{"dataset_id": "sixian2"}' null% (base) sxyi@sxyi-mbp llama-stack % llama-stack-client datasets list ``` ## Sources ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests.
This commit is contained in:
parent
64c6df8392
commit
caf1dac114
9 changed files with 134 additions and 0 deletions
|
@ -2291,6 +2291,39 @@
|
|||
"required": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/alpha/datasets/unregister": {
|
||||
"post": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK"
|
||||
}
|
||||
},
|
||||
"tags": [
|
||||
"Datasets"
|
||||
],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "X-LlamaStack-ProviderData",
|
||||
"in": "header",
|
||||
"description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
],
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/UnregisterDatasetRequest"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"jsonSchemaDialect": "https://json-schema.org/draft/2020-12/schema",
|
||||
|
@ -7917,6 +7950,18 @@
|
|||
"required": [
|
||||
"model_id"
|
||||
]
|
||||
},
|
||||
"UnregisterDatasetRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"dataset_id": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"dataset_id"
|
||||
]
|
||||
}
|
||||
},
|
||||
"responses": {}
|
||||
|
@ -8529,6 +8574,10 @@
|
|||
"name": "UnregisterModelRequest",
|
||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UnregisterModelRequest\" />"
|
||||
},
|
||||
{
|
||||
"name": "UnregisterDatasetRequest",
|
||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UnregisterDatasetRequest\" />"
|
||||
},
|
||||
{
|
||||
"name": "UnstructuredLogEvent",
|
||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UnstructuredLogEvent\" />"
|
||||
|
@ -8718,6 +8767,7 @@
|
|||
"URL",
|
||||
"UnregisterMemoryBankRequest",
|
||||
"UnregisterModelRequest",
|
||||
"UnregisterDatasetRequest",
|
||||
"UnstructuredLogEvent",
|
||||
"UserMessage",
|
||||
"VectorMemoryBank",
|
||||
|
|
|
@ -3253,6 +3253,14 @@ components:
|
|||
required:
|
||||
- model_id
|
||||
type: object
|
||||
UnregisterDatasetRequest:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
dataset_id:
|
||||
type: string
|
||||
required:
|
||||
- dataset_id
|
||||
type: object
|
||||
UnstructuredLogEvent:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
|
@ -3789,6 +3797,27 @@ paths:
|
|||
description: OK
|
||||
tags:
|
||||
- Datasets
|
||||
/alpha/datasets/unregister:
|
||||
post:
|
||||
parameters:
|
||||
- description: JSON-encoded provider data which will be made available to the
|
||||
adapter servicing the API
|
||||
in: header
|
||||
name: X-LlamaStack-ProviderData
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/UnregisterDatasetRequest'
|
||||
required: true
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
tags:
|
||||
- Datasets
|
||||
/alpha/eval-tasks/get:
|
||||
get:
|
||||
parameters:
|
||||
|
@ -5242,6 +5271,9 @@ tags:
|
|||
- description: <SchemaDefinition schemaRef="#/components/schemas/UnregisterModelRequest"
|
||||
/>
|
||||
name: UnregisterModelRequest
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/UnregisterDatasetRequest"
|
||||
/>
|
||||
name: UnregisterDatasetRequest
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/UnstructuredLogEvent"
|
||||
/>
|
||||
name: UnstructuredLogEvent
|
||||
|
@ -5418,6 +5450,7 @@ x-tagGroups:
|
|||
- URL
|
||||
- UnregisterMemoryBankRequest
|
||||
- UnregisterModelRequest
|
||||
- UnregisterDatasetRequest
|
||||
- UnstructuredLogEvent
|
||||
- UserMessage
|
||||
- VectorMemoryBank
|
||||
|
|
|
@ -78,6 +78,21 @@ class DatasetsClient(Datasets):
|
|||
|
||||
return [DatasetDefWithProvider(**x) for x in response.json()]
|
||||
|
||||
async def unregister_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
) -> None:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.delete(
|
||||
f"{self.base_url}/datasets/unregister",
|
||||
params={
|
||||
"dataset_id": dataset_id,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=60,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
async def run_main(host: str, port: int):
|
||||
client = DatasetsClient(f"http://{host}:{port}")
|
||||
|
|
|
@ -64,3 +64,9 @@ class Datasets(Protocol):
|
|||
|
||||
@webmethod(route="/datasets/list", method="GET")
|
||||
async def list_datasets(self) -> List[Dataset]: ...
|
||||
|
||||
@webmethod(route="/datasets/unregister", method="POST")
|
||||
async def unregister_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
) -> None: ...
|
||||
|
|
|
@ -57,6 +57,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
|
|||
return await p.unregister_memory_bank(obj.identifier)
|
||||
elif api == Api.inference:
|
||||
return await p.unregister_model(obj.identifier)
|
||||
elif api == Api.datasetio:
|
||||
return await p.unregister_dataset(obj.identifier)
|
||||
else:
|
||||
raise ValueError(f"Unregister not supported for {api}")
|
||||
|
||||
|
@ -354,6 +356,12 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
|||
)
|
||||
await self.register_object(dataset)
|
||||
|
||||
async def unregister_dataset(self, dataset_id: str) -> None:
|
||||
dataset = await self.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise ValueError(f"Dataset {dataset_id} not found")
|
||||
await self.unregister_object(dataset)
|
||||
|
||||
|
||||
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
||||
async def list_scoring_functions(self) -> List[ScoringFn]:
|
||||
|
|
|
@ -63,6 +63,8 @@ class MemoryBanksProtocolPrivate(Protocol):
|
|||
class DatasetsProtocolPrivate(Protocol):
|
||||
async def register_dataset(self, dataset: Dataset) -> None: ...
|
||||
|
||||
async def unregister_dataset(self, dataset_id: str) -> None: ...
|
||||
|
||||
|
||||
class ScoringFunctionsProtocolPrivate(Protocol):
|
||||
async def list_scoring_functions(self) -> List[ScoringFn]: ...
|
||||
|
|
|
@ -97,6 +97,9 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
dataset_impl=dataset_impl,
|
||||
)
|
||||
|
||||
async def unregister_dataset(self, dataset_id: str) -> None:
|
||||
del self.dataset_infos[dataset_id]
|
||||
|
||||
async def get_rows_paginated(
|
||||
self,
|
||||
dataset_id: str,
|
||||
|
|
|
@ -64,6 +64,11 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
)
|
||||
self.dataset_infos[dataset_def.identifier] = dataset_def
|
||||
|
||||
async def unregister_dataset(self, dataset_id: str) -> None:
|
||||
key = f"{DATASETS_PREFIX}{dataset_id}"
|
||||
await self.kvstore.delete(key=key)
|
||||
del self.dataset_infos[dataset_id]
|
||||
|
||||
async def get_rows_paginated(
|
||||
self,
|
||||
dataset_id: str,
|
||||
|
|
|
@ -81,6 +81,18 @@ class TestDatasetIO:
|
|||
assert len(response) == 1
|
||||
assert response[0].identifier == "test_dataset"
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
# unregister a dataset that does not exist
|
||||
await datasets_impl.unregister_dataset("test_dataset2")
|
||||
|
||||
await datasets_impl.unregister_dataset("test_dataset")
|
||||
response = await datasets_impl.list_datasets()
|
||||
assert isinstance(response, list)
|
||||
assert len(response) == 0
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await datasets_impl.unregister_dataset("test_dataset")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_rows_paginated(self, datasetio_stack):
|
||||
datasetio_impl, datasets_impl = datasetio_stack
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue