forked from phoenix-oss/llama-stack-mirror
unregister API for dataset (#507)
# What does this PR do? 1) Implement `unregister_dataset(dataset_id)` API in both llama stack routing table and providers: It removes {dataset_id -> Dataset} mapping from routing table and removes the dataset_id references in provider as well (ex. for huggingface, we use a KV store to store the dataset id => dataset. we delete it during unregistering as well) 2) expose the datasets/unregister_dataset api endpoint ## Test Plan **Unit test:** ` pytest llama_stack/providers/tests/datasetio/test_datasetio.py -m "huggingface" -v -s --tb=short --disable-warnings ` **Test on endpoint:** tested llama stack using an ollama distribution template: 1) start an ollama server 2) Start a llama stack server with the default ollama distribution config + dataset/datasetsio APIs + datasetio provider ``` ---- .../ollama-run.yaml ... apis: - agents - inference - memory - safety - telemetry - datasetio - datasets providers: datasetio: - provider_id: localfs provider_type: inline::localfs config: {} ... ``` saw that the new API showed up in startup script ``` Serving API datasets GET /alpha/datasets/get GET /alpha/datasets/list POST /alpha/datasets/register POST /alpha/datasets/unregister ``` 3) query `/alpha/datasets/unregister` through curl (since we have not implemented unregister api in llama stack client) ``` (base) sxyi@sxyi-mbp llama-stack % llama-stack-client datasets register --dataset-id sixian --url https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/chat.rst --schema {} (base) sxyi@sxyi-mbp llama-stack % llama-stack-client datasets list ┏━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━┓ ┃ identifier ┃ provider_id ┃ metadata ┃ type ┃ ┡━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━┩ │ sixian │ localfs │ {} │ dataset │ └────────────┴─────────────┴──────────┴─────────┘ (base) sxyi@sxyi-mbp llama-stack % llama-stack-client datasets register --dataset-id sixian2 --url https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/chat.rst 
--schema {} (base) sxyi@sxyi-mbp llama-stack % llama-stack-client datasets list ┏━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━┓ ┃ identifier ┃ provider_id ┃ metadata ┃ type ┃ ┡━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━┩ │ sixian │ localfs │ {} │ dataset │ │ sixian2 │ localfs │ {} │ dataset │ └────────────┴─────────────┴──────────┴─────────┘ (base) sxyi@sxyi-mbp llama-stack % curl http://localhost:5001/alpha/datasets/unregister \ -H "Content-Type: application/json" \ -d '{"dataset_id": "sixian"}' null% (base) sxyi@sxyi-mbp llama-stack % llama-stack-client datasets list ┏━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━┓ ┃ identifier ┃ provider_id ┃ metadata ┃ type ┃ ┡━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━┩ │ sixian2 │ localfs │ {} │ dataset │ └────────────┴─────────────┴──────────┴─────────┘ (base) sxyi@sxyi-mbp llama-stack % curl http://localhost:5001/alpha/datasets/unregister \ -H "Content-Type: application/json" \ -d '{"dataset_id": "sixian2"}' null% (base) sxyi@sxyi-mbp llama-stack % llama-stack-client datasets list ``` ## Sources ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests.
This commit is contained in:
parent
64c6df8392
commit
caf1dac114
9 changed files with 134 additions and 0 deletions
|
@ -2291,6 +2291,39 @@
|
||||||
"required": true
|
"required": true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
"/alpha/datasets/unregister": {
|
||||||
|
"post": {
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "OK"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"tags": [
|
||||||
|
"Datasets"
|
||||||
|
],
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"name": "X-LlamaStack-ProviderData",
|
||||||
|
"in": "header",
|
||||||
|
"description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
|
||||||
|
"required": false,
|
||||||
|
"schema": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"requestBody": {
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/UnregisterDatasetRequest"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"jsonSchemaDialect": "https://json-schema.org/draft/2020-12/schema",
|
"jsonSchemaDialect": "https://json-schema.org/draft/2020-12/schema",
|
||||||
|
@ -7917,6 +7950,18 @@
|
||||||
"required": [
|
"required": [
|
||||||
"model_id"
|
"model_id"
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
"UnregisterDatasetRequest": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"dataset_id": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"required": [
|
||||||
|
"dataset_id"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"responses": {}
|
"responses": {}
|
||||||
|
@ -8529,6 +8574,10 @@
|
||||||
"name": "UnregisterModelRequest",
|
"name": "UnregisterModelRequest",
|
||||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UnregisterModelRequest\" />"
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UnregisterModelRequest\" />"
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "UnregisterDatasetRequest",
|
||||||
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UnregisterDatasetRequest\" />"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "UnstructuredLogEvent",
|
"name": "UnstructuredLogEvent",
|
||||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UnstructuredLogEvent\" />"
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UnstructuredLogEvent\" />"
|
||||||
|
@ -8718,6 +8767,7 @@
|
||||||
"URL",
|
"URL",
|
||||||
"UnregisterMemoryBankRequest",
|
"UnregisterMemoryBankRequest",
|
||||||
"UnregisterModelRequest",
|
"UnregisterModelRequest",
|
||||||
|
"UnregisterDatasetRequest",
|
||||||
"UnstructuredLogEvent",
|
"UnstructuredLogEvent",
|
||||||
"UserMessage",
|
"UserMessage",
|
||||||
"VectorMemoryBank",
|
"VectorMemoryBank",
|
||||||
|
|
|
@ -3253,6 +3253,14 @@ components:
|
||||||
required:
|
required:
|
||||||
- model_id
|
- model_id
|
||||||
type: object
|
type: object
|
||||||
|
UnregisterDatasetRequest:
|
||||||
|
additionalProperties: false
|
||||||
|
properties:
|
||||||
|
dataset_id:
|
||||||
|
type: string
|
||||||
|
required:
|
||||||
|
- dataset_id
|
||||||
|
type: object
|
||||||
UnstructuredLogEvent:
|
UnstructuredLogEvent:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
properties:
|
properties:
|
||||||
|
@ -3789,6 +3797,27 @@ paths:
|
||||||
description: OK
|
description: OK
|
||||||
tags:
|
tags:
|
||||||
- Datasets
|
- Datasets
|
||||||
|
/alpha/datasets/unregister:
|
||||||
|
post:
|
||||||
|
parameters:
|
||||||
|
- description: JSON-encoded provider data which will be made available to the
|
||||||
|
adapter servicing the API
|
||||||
|
in: header
|
||||||
|
name: X-LlamaStack-ProviderData
|
||||||
|
required: false
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
requestBody:
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/UnregisterDatasetRequest'
|
||||||
|
required: true
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: OK
|
||||||
|
tags:
|
||||||
|
- Datasets
|
||||||
/alpha/eval-tasks/get:
|
/alpha/eval-tasks/get:
|
||||||
get:
|
get:
|
||||||
parameters:
|
parameters:
|
||||||
|
@ -5242,6 +5271,9 @@ tags:
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/UnregisterModelRequest"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/UnregisterModelRequest"
|
||||||
/>
|
/>
|
||||||
name: UnregisterModelRequest
|
name: UnregisterModelRequest
|
||||||
|
- description: <SchemaDefinition schemaRef="#/components/schemas/UnregisterDatasetRequest"
|
||||||
|
/>
|
||||||
|
name: UnregisterDatasetRequest
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/UnstructuredLogEvent"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/UnstructuredLogEvent"
|
||||||
/>
|
/>
|
||||||
name: UnstructuredLogEvent
|
name: UnstructuredLogEvent
|
||||||
|
@ -5418,6 +5450,7 @@ x-tagGroups:
|
||||||
- URL
|
- URL
|
||||||
- UnregisterMemoryBankRequest
|
- UnregisterMemoryBankRequest
|
||||||
- UnregisterModelRequest
|
- UnregisterModelRequest
|
||||||
|
- UnregisterDatasetRequest
|
||||||
- UnstructuredLogEvent
|
- UnstructuredLogEvent
|
||||||
- UserMessage
|
- UserMessage
|
||||||
- VectorMemoryBank
|
- VectorMemoryBank
|
||||||
|
|
|
@ -78,6 +78,21 @@ class DatasetsClient(Datasets):
|
||||||
|
|
||||||
return [DatasetDefWithProvider(**x) for x in response.json()]
|
return [DatasetDefWithProvider(**x) for x in response.json()]
|
||||||
|
|
||||||
|
async def unregister_dataset(
    self,
    dataset_id: str,
) -> None:
    """Ask the server to unregister the dataset identified by ``dataset_id``.

    The ``/datasets/unregister`` route is declared as a POST endpoint that
    takes ``dataset_id`` in a JSON request body, so the request is sent as
    a POST with a JSON payload (a DELETE with query parameters would not
    match the declared route).

    Args:
        dataset_id: Identifier of the dataset to unregister.

    Raises:
        httpx.HTTPStatusError: If the server responds with a 4xx/5xx status.
    """
    async with httpx.AsyncClient() as client:
        response = await client.post(
            f"{self.base_url}/datasets/unregister",
            json={"dataset_id": dataset_id},
            headers={"Content-Type": "application/json"},
            timeout=60,
        )
        response.raise_for_status()
|
||||||
|
|
||||||
|
|
||||||
async def run_main(host: str, port: int):
|
async def run_main(host: str, port: int):
|
||||||
client = DatasetsClient(f"http://{host}:{port}")
|
client = DatasetsClient(f"http://{host}:{port}")
|
||||||
|
|
|
@ -64,3 +64,9 @@ class Datasets(Protocol):
|
||||||
|
|
||||||
@webmethod(route="/datasets/list", method="GET")
|
@webmethod(route="/datasets/list", method="GET")
|
||||||
async def list_datasets(self) -> List[Dataset]: ...
|
async def list_datasets(self) -> List[Dataset]: ...
|
||||||
|
|
||||||
|
@webmethod(route="/datasets/unregister", method="POST")
async def unregister_dataset(self, dataset_id: str) -> None:
    """Unregister the dataset identified by ``dataset_id``.

    Exposed as ``POST /datasets/unregister``; returns nothing.
    """
    ...
|
||||||
|
|
|
@ -57,6 +57,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
|
||||||
return await p.unregister_memory_bank(obj.identifier)
|
return await p.unregister_memory_bank(obj.identifier)
|
||||||
elif api == Api.inference:
|
elif api == Api.inference:
|
||||||
return await p.unregister_model(obj.identifier)
|
return await p.unregister_model(obj.identifier)
|
||||||
|
elif api == Api.datasetio:
|
||||||
|
return await p.unregister_dataset(obj.identifier)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unregister not supported for {api}")
|
raise ValueError(f"Unregister not supported for {api}")
|
||||||
|
|
||||||
|
@ -354,6 +356,12 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||||
)
|
)
|
||||||
await self.register_object(dataset)
|
await self.register_object(dataset)
|
||||||
|
|
||||||
|
async def unregister_dataset(self, dataset_id: str) -> None:
    """Remove a registered dataset from the routing table.

    The dataset is looked up with ``get_dataset``; when that lookup yields
    ``None`` a ``ValueError`` is raised, otherwise the dataset object is
    handed to ``unregister_object``.

    Raises:
        ValueError: If ``get_dataset`` returns ``None`` for ``dataset_id``.
    """
    existing = await self.get_dataset(dataset_id)
    if existing is not None:
        await self.unregister_object(existing)
        return
    raise ValueError(f"Dataset {dataset_id} not found")
|
||||||
|
|
||||||
|
|
||||||
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
||||||
async def list_scoring_functions(self) -> List[ScoringFn]:
|
async def list_scoring_functions(self) -> List[ScoringFn]:
|
||||||
|
|
|
@ -63,6 +63,8 @@ class MemoryBanksProtocolPrivate(Protocol):
|
||||||
class DatasetsProtocolPrivate(Protocol):
|
class DatasetsProtocolPrivate(Protocol):
|
||||||
async def register_dataset(self, dataset: Dataset) -> None: ...
|
async def register_dataset(self, dataset: Dataset) -> None: ...
|
||||||
|
|
||||||
|
# Provider-side counterpart of register_dataset: takes the id of the dataset
# to drop and returns nothing. Protocol stub only; providers supply the body.
async def unregister_dataset(self, dataset_id: str) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
class ScoringFunctionsProtocolPrivate(Protocol):
|
class ScoringFunctionsProtocolPrivate(Protocol):
|
||||||
async def list_scoring_functions(self) -> List[ScoringFn]: ...
|
async def list_scoring_functions(self) -> List[ScoringFn]: ...
|
||||||
|
|
|
@ -97,6 +97,9 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
||||||
dataset_impl=dataset_impl,
|
dataset_impl=dataset_impl,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def unregister_dataset(self, dataset_id: str) -> None:
    """Forget the entry stored under ``dataset_id`` in ``self.dataset_infos``.

    The entry is removed with ``del``, so an id that is not present
    propagates the mapping's own lookup error.
    """
    registry = self.dataset_infos
    del registry[dataset_id]
|
||||||
|
|
||||||
async def get_rows_paginated(
|
async def get_rows_paginated(
|
||||||
self,
|
self,
|
||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
|
|
|
@ -64,6 +64,11 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
||||||
)
|
)
|
||||||
self.dataset_infos[dataset_def.identifier] = dataset_def
|
self.dataset_infos[dataset_def.identifier] = dataset_def
|
||||||
|
|
||||||
|
async def unregister_dataset(self, dataset_id: str) -> None:
    """Remove a dataset from both the KV store and the in-memory registry.

    The persisted entry is deleted first, under the key formed from
    ``DATASETS_PREFIX`` followed by ``dataset_id``; afterwards the matching
    entry is dropped from ``self.dataset_infos``.
    """
    await self.kvstore.delete(key=f"{DATASETS_PREFIX}{dataset_id}")
    del self.dataset_infos[dataset_id]
|
||||||
|
|
||||||
async def get_rows_paginated(
|
async def get_rows_paginated(
|
||||||
self,
|
self,
|
||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
|
|
|
@ -81,6 +81,18 @@ class TestDatasetIO:
|
||||||
assert len(response) == 1
|
assert len(response) == 1
|
||||||
assert response[0].identifier == "test_dataset"
|
assert response[0].identifier == "test_dataset"
|
||||||
|
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
# unregister a dataset that does not exist
|
||||||
|
await datasets_impl.unregister_dataset("test_dataset2")
|
||||||
|
|
||||||
|
await datasets_impl.unregister_dataset("test_dataset")
|
||||||
|
response = await datasets_impl.list_datasets()
|
||||||
|
assert isinstance(response, list)
|
||||||
|
assert len(response) == 0
|
||||||
|
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
await datasets_impl.unregister_dataset("test_dataset")
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_rows_paginated(self, datasetio_stack):
|
async def test_get_rows_paginated(self, datasetio_stack):
|
||||||
datasetio_impl, datasets_impl = datasetio_stack
|
datasetio_impl, datasets_impl = datasetio_stack
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue