From aefa84e70a6430c1e38475a7795142cc464a5d90 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Tue, 22 Oct 2024 20:00:43 -0700 Subject: [PATCH] wip --- llama_stack/apis/datasets/datasets.py | 2 +- .../meta_reference/datasetio/datasetio.py | 13 +- llama_stack/providers/registry/inference.py | 2 +- .../tests/datasetio/test_dataset.jsonl | 5 + .../tests/datasetio/test_datasetio.py | 149 ++++++++++++------ 5 files changed, 117 insertions(+), 54 deletions(-) create mode 100644 llama_stack/providers/tests/datasetio/test_dataset.jsonl diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index e2b764d7f..7a56049bf 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -20,7 +20,7 @@ class DatasetDef(BaseModel): identifier: str = Field( description="A unique name for the dataset", ) - columns_schema: Dict[str, ParamType] = Field( + dataset_schema: Dict[str, ParamType] = Field( description="The schema definition for this dataset", ) url: URL diff --git a/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py b/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py index a8e648e46..3786b5857 100644 --- a/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py +++ b/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py @@ -3,10 +3,10 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import io from typing import List, Optional import pandas - from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.datasetio import * # noqa: F403 @@ -14,6 +14,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from llama_stack.providers.datatypes import DatasetsProtocolPrivate +from llama_stack.providers.utils.memory.vector_store import parse_data_url from .config import MetaReferenceDatasetIOConfig @@ -57,6 +58,15 @@ class PandasDataframeDataset(BaseDataset): else: return self.df.iloc[idx].to_dict() + def validate_dataset_schema(self) -> None: + if self.df is None: + self.load() + + print(self.dataset_def.dataset_schema) + # get columns names + # columns = self.df[self.dataset_def.dataset_schema.keys()] + print(self.df.columns) + def load(self) -> None: if self.df is not None: return @@ -105,6 +115,7 @@ class MetaReferenceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): dataset_def: DatasetDef, ) -> None: dataset_impl = PandasDataframeDataset(dataset_def) + dataset_impl.validate_dataset_schema() self.dataset_infos[dataset_def.identifier] = DatasetInfo( dataset_def=dataset_def, dataset_impl=dataset_impl, diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 6f8bc2c6e..065b5ce76 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -36,7 +36,7 @@ def available_providers() -> List[ProviderSpec]: pip_packages=( META_REFERENCE_DEPS + [ - "fbgemm-gpu==0.8.0", + "fbgemm-gpu==1.0.0", ] ), module="llama_stack.providers.impls.meta_reference.inference", diff --git a/llama_stack/providers/tests/datasetio/test_dataset.jsonl b/llama_stack/providers/tests/datasetio/test_dataset.jsonl new file mode 100644 index 000000000..dfdfff458 --- /dev/null +++ b/llama_stack/providers/tests/datasetio/test_dataset.jsonl @@ -0,0 +1,5 @@ +{"input_query": "What is the capital of France?", "generated_answer": "London", "expected_answer": "Paris"} +{"input_query": "Who is the CEO of Meta?", "generated_answer": "Mark Zuckerberg", "expected_answer": "Mark Zuckerberg"} +{"input_query": "What is the largest planet in our solar system?", "generated_answer": "Jupiter", "expected_answer": "Jupiter"} +{"input_query": "What is the smallest country in the world?", "generated_answer": "China", "expected_answer": "Vatican City"} +{"input_query": "What is the currency of Japan?", "generated_answer": "Yen", "expected_answer": "Yen"} diff --git a/llama_stack/providers/tests/datasetio/test_datasetio.py b/llama_stack/providers/tests/datasetio/test_datasetio.py index 85235a64b..4114207aa 100644 --- a/llama_stack/providers/tests/datasetio/test_datasetio.py +++ b/llama_stack/providers/tests/datasetio/test_datasetio.py @@ -8,8 +8,13 @@ import os import pytest import pytest_asyncio +from llama_stack.apis.common.type_system import * # noqa: F403 from llama_stack.apis.datasetio import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403 +import base64 +import mimetypes +from pathlib import Path + from llama_stack.providers.tests.resolver import resolve_impls_for_test # How to run this test: @@ -41,6 +46,22 @@ async def datasetio_settings(): } +def data_url_from_file(file_path: str) -> str: + if not os.path.exists(file_path): + raise FileNotFoundError(f"File not found: {file_path}") + + with open(file_path, "rb") as file: + file_content = file.read() + + base64_content = base64.b64encode(file_content).decode("utf-8") + mime_type, _ = mimetypes.guess_type(file_path) + print(mime_type) + + data_url = f"data:{mime_type};base64,{base64_content}" + + return data_url + + async def register_dataset(datasets_impl: Datasets): dataset = DatasetDefWithProvider( identifier="test_dataset", @@ -48,62 +69,88 @@ async def register_dataset(datasets_impl: Datasets): url=URL( uri="https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv", ), - columns_schema={}, + dataset_schema={}, ) await datasets_impl.register_dataset(dataset) +async def register_local_dataset(datasets_impl: Datasets): + test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.jsonl" + test_jsonl_url = data_url_from_file(str(test_file)) + dataset = DatasetDefWithProvider( + identifier="test_dataset", + provider_id=os.environ["PROVIDER_ID"], + url=URL( + uri=test_jsonl_url, + ), + dataset_schema={ + "generated_answer": StringType(), + "expected_answer": StringType(), + "input_query": StringType(), + }, + ) + await datasets_impl.register_dataset(dataset) + + +# @pytest.mark.asyncio +# async def test_datasets_list(datasetio_settings): +# # NOTE: this needs you to ensure that you are starting from a clean state +# # but so far we don't have an unregister API unfortunately, so be careful +# datasets_impl = datasetio_settings["datasets_impl"] +# response = await datasets_impl.list_datasets() +# assert isinstance(response, list) +# assert len(response) == 0 + + +# @pytest.mark.asyncio +# async def test_datasets_register(datasetio_settings): +# # NOTE: this needs you to ensure that you are starting from a clean state +# # but so far we don't have an unregister API unfortunately, so be careful +# datasets_impl = datasetio_settings["datasets_impl"] +# await register_dataset(datasets_impl) + +# response = await datasets_impl.list_datasets() +# assert isinstance(response, list) +# assert len(response) == 1 + +# # register same dataset with same id again will fail +# await register_dataset(datasets_impl) +# response = await datasets_impl.list_datasets() +# assert isinstance(response, list) +# assert len(response) == 1 +# assert response[0].identifier == "test_dataset" + + +# @pytest.mark.asyncio +# async def test_get_rows_paginated(datasetio_settings): +# datasetio_impl = datasetio_settings["datasetio_impl"] +# datasets_impl = datasetio_settings["datasets_impl"] +# await register_dataset(datasets_impl) + +# response = await datasetio_impl.get_rows_paginated( +# dataset_id="test_dataset", +# rows_in_page=3, +# ) + +# assert isinstance(response.rows, list) +# assert len(response.rows) == 3 +# assert response.next_page_token == "3" + +# # iterate over all rows +# response = await datasetio_impl.get_rows_paginated( +# dataset_id="test_dataset", +# rows_in_page=10, +# page_token=response.next_page_token, +# ) + +# assert isinstance(response.rows, list) +# assert len(response.rows) == 10 +# assert response.next_page_token == "13" + + @pytest.mark.asyncio -async def test_datasets_list(datasetio_settings): +async def test_datasets_validation(datasetio_settings): # NOTE: this needs you to ensure that you are starting from a clean state # but so far we don't have an unregister API unfortunately, so be careful datasets_impl = datasetio_settings["datasets_impl"] - response = await datasets_impl.list_datasets() - assert isinstance(response, list) - assert len(response) == 0 - - -@pytest.mark.asyncio -async def test_datasets_register(datasetio_settings): - # NOTE: this needs you to ensure that you are starting from a clean state - # but so far we don't have an unregister API unfortunately, so be careful - datasets_impl = datasetio_settings["datasets_impl"] - await register_dataset(datasets_impl) - - response = await datasets_impl.list_datasets() - assert isinstance(response, list) - assert len(response) == 1 - - # register same dataset with same id again will fail - await register_dataset(datasets_impl) - response = await datasets_impl.list_datasets() - assert isinstance(response, list) - assert len(response) == 1 - assert response[0].identifier == "test_dataset" - - -@pytest.mark.asyncio -async def test_get_rows_paginated(datasetio_settings): - datasetio_impl = datasetio_settings["datasetio_impl"] - datasets_impl = datasetio_settings["datasets_impl"] - await register_dataset(datasets_impl) - - response = await datasetio_impl.get_rows_paginated( - dataset_id="test_dataset", - rows_in_page=3, - ) - - assert isinstance(response.rows, list) - assert len(response.rows) == 3 - assert response.next_page_token == "3" - - # iterate over all rows - response = await datasetio_impl.get_rows_paginated( - dataset_id="test_dataset", - rows_in_page=10, - page_token=response.next_page_token, - ) - - assert isinstance(response.rows, list) - assert len(response.rows) == 10 - assert response.next_page_token == "13" + await register_local_dataset(datasets_impl)