mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-16 06:53:47 +00:00
wip
This commit is contained in:
parent
821810657f
commit
aefa84e70a
5 changed files with 117 additions and 54 deletions
|
@ -8,8 +8,13 @@ import os
|
|||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from llama_stack.apis.datasetio import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
import base64
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test
|
||||
|
||||
# How to run this test:
|
||||
|
@ -41,6 +46,22 @@ async def datasetio_settings():
|
|||
}
|
||||
|
||||
|
||||
def data_url_from_file(file_path: str) -> str:
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
with open(file_path, "rb") as file:
|
||||
file_content = file.read()
|
||||
|
||||
base64_content = base64.b64encode(file_content).decode("utf-8")
|
||||
mime_type, _ = mimetypes.guess_type(file_path)
|
||||
print(mime_type)
|
||||
|
||||
data_url = f"data:{mime_type};base64,{base64_content}"
|
||||
|
||||
return data_url
|
||||
|
||||
|
||||
async def register_dataset(datasets_impl: Datasets):
|
||||
dataset = DatasetDefWithProvider(
|
||||
identifier="test_dataset",
|
||||
|
@ -48,62 +69,88 @@ async def register_dataset(datasets_impl: Datasets):
|
|||
url=URL(
|
||||
uri="https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
|
||||
),
|
||||
columns_schema={},
|
||||
dataset_schema={},
|
||||
)
|
||||
await datasets_impl.register_dataset(dataset)
|
||||
|
||||
|
||||
async def register_local_dataset(datasets_impl: Datasets):
|
||||
test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.jsonl"
|
||||
test_jsonl_url = data_url_from_file(str(test_file))
|
||||
dataset = DatasetDefWithProvider(
|
||||
identifier="test_dataset",
|
||||
provider_id=os.environ["PROVIDER_ID"],
|
||||
url=URL(
|
||||
uri=test_jsonl_url,
|
||||
),
|
||||
dataset_schema={
|
||||
"generated_answer": StringType(),
|
||||
"expected_answer": StringType(),
|
||||
"input_query": StringType(),
|
||||
},
|
||||
)
|
||||
await datasets_impl.register_dataset(dataset)
|
||||
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_datasets_list(datasetio_settings):
|
||||
# # NOTE: this needs you to ensure that you are starting from a clean state
|
||||
# # but so far we don't have an unregister API unfortunately, so be careful
|
||||
# datasets_impl = datasetio_settings["datasets_impl"]
|
||||
# response = await datasets_impl.list_datasets()
|
||||
# assert isinstance(response, list)
|
||||
# assert len(response) == 0
|
||||
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_datasets_register(datasetio_settings):
|
||||
# # NOTE: this needs you to ensure that you are starting from a clean state
|
||||
# # but so far we don't have an unregister API unfortunately, so be careful
|
||||
# datasets_impl = datasetio_settings["datasets_impl"]
|
||||
# await register_dataset(datasets_impl)
|
||||
|
||||
# response = await datasets_impl.list_datasets()
|
||||
# assert isinstance(response, list)
|
||||
# assert len(response) == 1
|
||||
|
||||
# # register same dataset with same id again will fail
|
||||
# await register_dataset(datasets_impl)
|
||||
# response = await datasets_impl.list_datasets()
|
||||
# assert isinstance(response, list)
|
||||
# assert len(response) == 1
|
||||
# assert response[0].identifier == "test_dataset"
|
||||
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_get_rows_paginated(datasetio_settings):
|
||||
# datasetio_impl = datasetio_settings["datasetio_impl"]
|
||||
# datasets_impl = datasetio_settings["datasets_impl"]
|
||||
# await register_dataset(datasets_impl)
|
||||
|
||||
# response = await datasetio_impl.get_rows_paginated(
|
||||
# dataset_id="test_dataset",
|
||||
# rows_in_page=3,
|
||||
# )
|
||||
|
||||
# assert isinstance(response.rows, list)
|
||||
# assert len(response.rows) == 3
|
||||
# assert response.next_page_token == "3"
|
||||
|
||||
# # iterate over all rows
|
||||
# response = await datasetio_impl.get_rows_paginated(
|
||||
# dataset_id="test_dataset",
|
||||
# rows_in_page=10,
|
||||
# page_token=response.next_page_token,
|
||||
# )
|
||||
|
||||
# assert isinstance(response.rows, list)
|
||||
# assert len(response.rows) == 10
|
||||
# assert response.next_page_token == "13"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_datasets_list(datasetio_settings):
|
||||
async def test_datasets_validation(datasetio_settings):
|
||||
# NOTE: this needs you to ensure that you are starting from a clean state
|
||||
# but so far we don't have an unregister API unfortunately, so be careful
|
||||
datasets_impl = datasetio_settings["datasets_impl"]
|
||||
response = await datasets_impl.list_datasets()
|
||||
assert isinstance(response, list)
|
||||
assert len(response) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_datasets_register(datasetio_settings):
|
||||
# NOTE: this needs you to ensure that you are starting from a clean state
|
||||
# but so far we don't have an unregister API unfortunately, so be careful
|
||||
datasets_impl = datasetio_settings["datasets_impl"]
|
||||
await register_dataset(datasets_impl)
|
||||
|
||||
response = await datasets_impl.list_datasets()
|
||||
assert isinstance(response, list)
|
||||
assert len(response) == 1
|
||||
|
||||
# register same dataset with same id again will fail
|
||||
await register_dataset(datasets_impl)
|
||||
response = await datasets_impl.list_datasets()
|
||||
assert isinstance(response, list)
|
||||
assert len(response) == 1
|
||||
assert response[0].identifier == "test_dataset"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_rows_paginated(datasetio_settings):
|
||||
datasetio_impl = datasetio_settings["datasetio_impl"]
|
||||
datasets_impl = datasetio_settings["datasets_impl"]
|
||||
await register_dataset(datasets_impl)
|
||||
|
||||
response = await datasetio_impl.get_rows_paginated(
|
||||
dataset_id="test_dataset",
|
||||
rows_in_page=3,
|
||||
)
|
||||
|
||||
assert isinstance(response.rows, list)
|
||||
assert len(response.rows) == 3
|
||||
assert response.next_page_token == "3"
|
||||
|
||||
# iterate over all rows
|
||||
response = await datasetio_impl.get_rows_paginated(
|
||||
dataset_id="test_dataset",
|
||||
rows_in_page=10,
|
||||
page_token=response.next_page_token,
|
||||
)
|
||||
|
||||
assert isinstance(response.rows, list)
|
||||
assert len(response.rows) == 10
|
||||
assert response.next_page_token == "13"
|
||||
await register_local_dataset(datasets_impl)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue