unregister

This commit is contained in:
Xi Yan 2025-03-15 16:37:46 -07:00
parent cf225c9710
commit 4fee3af91f
4 changed files with 79 additions and 96 deletions

View file

@ -13,7 +13,7 @@ import pytest
@pytest.mark.parametrize(
"purpose, source, provider_id",
"purpose, source, provider_id, limit",
[
(
"eval/messages-answer",
@ -22,16 +22,48 @@ import pytest
"uri": "huggingface://datasets/llamastack/simpleqa?split=train",
},
"huggingface",
10,
),
(
"eval/messages-answer",
{
"type": "rows",
"rows": [
{
"messages": [{"role": "user", "content": "Hello, world!"}],
"answer": "Hello, world!",
},
{
"messages": [
{
"role": "user",
"content": "What is the capital of France?",
}
],
"answer": "Paris",
},
],
},
"localfs",
2,
),
],
)
def test_register_dataset(llama_stack_client, purpose, source, provider_id):
def test_register_and_iterrows(llama_stack_client, purpose, source, provider_id, limit):
dataset = llama_stack_client.datasets.register(
purpose=purpose,
source=source,
)
assert dataset.identifier is not None
assert dataset.provider_id == provider_id
iterrow_response = llama_stack_client.datasets.iterrows(dataset.identifier, limit=10)
assert len(iterrow_response.data) == 10
assert iterrow_response.next_index is not None
iterrow_response = llama_stack_client.datasets.iterrows(
dataset.identifier, limit=limit
)
assert len(iterrow_response.data) == limit
dataset_list = llama_stack_client.datasets.list()
assert dataset.identifier in [d.identifier for d in dataset_list]
llama_stack_client.datasets.unregister(dataset.identifier)
dataset_list = llama_stack_client.datasets.list()
assert dataset.identifier not in [d.identifier for d in dataset_list]

View file

@ -10,11 +10,27 @@ from rich.pretty import pprint
def test_register_dataset():
client = LlamaStackClient(base_url="http://localhost:8321")
# dataset = client.datasets.register(
# purpose="eval/messages-answer",
# source={
# "type": "uri",
# "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
# },
# )
dataset = client.datasets.register(
purpose="eval/messages-answer",
source={
"type": "uri",
"uri": "huggingface://datasets/llamastack/simpleqa?split=train",
"type": "rows",
"rows": [
{
"messages": [{"role": "user", "content": "Hello, world!"}],
"answer": "Hello, world!",
},
{
"messages": [{"role": "user", "content": "What is the capital of France?"}],
"answer": "Paris",
},
],
},
)
dataset_id = dataset.identifier