From 683a370d23c8cd4e93fc640cdac5ecf25bba243b Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 6 Nov 2024 10:03:49 -0800 Subject: [PATCH] wip tests --- llama_stack/apis/eval/eval.py | 7 +- llama_stack/apis/eval_tasks/__init__.py | 7 + .../tests/datasetio/test_datasetio.py | 193 ++++++------------ .../tests/datasetio/test_datasetio_old.py | 148 ++++++++++++++ 4 files changed, 227 insertions(+), 128 deletions(-) create mode 100644 llama_stack/apis/eval_tasks/__init__.py create mode 100644 llama_stack/providers/tests/datasetio/test_datasetio_old.py diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py index d333d6fec..dc43366a7 100644 --- a/llama_stack/apis/eval/eval.py +++ b/llama_stack/apis/eval/eval.py @@ -51,6 +51,11 @@ class AppEvalTaskConfig(BaseModel): # we could optinally add any specific dataset config here +EvalTaskConfig = Annotated[ + Union[BenchmarkEvalTaskConfig, AppEvalTaskConfig], Field(discriminator="type") +] + + @json_schema_type class EvaluateResponse(BaseModel): generations: List[Dict[str, Any]] @@ -70,7 +75,7 @@ class Eval(Protocol): async def run_eval( self, eval_task_def: EvalTaskDef, # type: ignore - eval_task_config: AppEvalTaskConfig, # type: ignore + eval_task_config: EvalTaskConfig, # type: ignore ) -> Job: ... @webmethod(route="/eval/evaluate_rows", method="POST") diff --git a/llama_stack/apis/eval_tasks/__init__.py b/llama_stack/apis/eval_tasks/__init__.py new file mode 100644 index 000000000..7ca216706 --- /dev/null +++ b/llama_stack/apis/eval_tasks/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .eval_tasks import * # noqa: F401 F403 diff --git a/llama_stack/providers/tests/datasetio/test_datasetio.py b/llama_stack/providers/tests/datasetio/test_datasetio.py index 866b1e270..4cc46e619 100644 --- a/llama_stack/providers/tests/datasetio/test_datasetio.py +++ b/llama_stack/providers/tests/datasetio/test_datasetio.py @@ -3,146 +3,85 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import os import pytest -import pytest_asyncio -from llama_stack.apis.common.type_system import * # noqa: F403 -from llama_stack.apis.datasetio import * # noqa: F403 +from llama_stack.apis.memory import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403 -import base64 -import mimetypes -from pathlib import Path - -from llama_stack.providers.tests.resolver import resolve_impls_for_test # How to run this test: # -# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky -# since it depends on the provider you are testing. On top of that you need -# `pytest` and `pytest-asyncio` installed. -# -# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing. -# -# 3. Run: -# -# ```bash -# PROVIDER_ID= \ -# PROVIDER_CONFIG=provider_config.yaml \ -# pytest -s llama_stack/providers/tests/datasetio/test_datasetio.py \ -# --tb=short --disable-warnings -# ``` +# pytest llama_stack/providers/tests/memory/test_memory.py +# -m "meta_reference" +# -v -s --tb=short --disable-warnings -@pytest_asyncio.fixture(scope="session") -async def datasetio_settings(): - impls = await resolve_impls_for_test( - Api.datasetio, - ) - return { - "datasetio_impl": impls[Api.datasetio], - "datasets_impl": impls[Api.datasets], - } - - -def data_url_from_file(file_path: str) -> str: - if not os.path.exists(file_path): - raise FileNotFoundError(f"File not found: {file_path}") - - with open(file_path, "rb") as file: - file_content = file.read() - - base64_content = base64.b64encode(file_content).decode("utf-8") - mime_type, _ = mimetypes.guess_type(file_path) - - data_url = f"data:{mime_type};base64,{base64_content}" - - return data_url - - -async def register_dataset( - datasets_impl: Datasets, for_generation=False, dataset_id="test_dataset" -): - test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv" - test_url = data_url_from_file(str(test_file)) - - if for_generation: - dataset_schema = { - "expected_answer": StringType(), - "input_query": StringType(), - "chat_completion_input": ChatCompletionInputType(), - } - else: - dataset_schema = { - "expected_answer": StringType(), - "input_query": StringType(), - "generated_answer": StringType(), - } - - dataset = DatasetDefWithProvider( - identifier=dataset_id, - provider_id=os.environ.get("DATASETIO_PROVIDER_ID", None) - or os.environ["PROVIDER_ID"], - url=URL( - uri=test_url, +@pytest.fixture +def sample_documents(): + return [ + MemoryBankDocument( + document_id="doc1", + content="Python is a high-level programming language.", + metadata={"category": "programming", "difficulty": "beginner"}, ), - dataset_schema=dataset_schema, - ) - await datasets_impl.register_dataset(dataset) + MemoryBankDocument( + document_id="doc2", + content="Machine learning is a subset of artificial intelligence.", + metadata={"category": "AI", "difficulty": "advanced"}, + ), + MemoryBankDocument( + document_id="doc3", + content="Data structures are fundamental to computer science.", + metadata={"category": "computer science", "difficulty": "intermediate"}, + ), + MemoryBankDocument( + document_id="doc4", + content="Neural networks are inspired by biological neural networks.", + metadata={"category": "AI", "difficulty": "advanced"}, + ), + ] -@pytest.mark.asyncio -async def test_datasets_list(datasetio_settings): - # NOTE: this needs you to ensure that you are starting from a clean state - # but so far we don't have an unregister API unfortunately, so be careful - datasets_impl = datasetio_settings["datasets_impl"] - response = await datasets_impl.list_datasets() - assert isinstance(response, list) - assert len(response) == 0 - - -@pytest.mark.asyncio -async def test_datasets_register(datasetio_settings): - # NOTE: this needs you to ensure that you are starting from a clean state - # but so far we don't have an unregister API unfortunately, so be careful - datasets_impl = datasetio_settings["datasets_impl"] - await register_dataset(datasets_impl) - - response = await datasets_impl.list_datasets() - assert isinstance(response, list) - assert len(response) == 1 - - # register same dataset with same id again will fail - await register_dataset(datasets_impl) - response = await datasets_impl.list_datasets() - assert isinstance(response, list) - assert len(response) == 1 - assert response[0].identifier == "test_dataset" - - -@pytest.mark.asyncio -async def test_get_rows_paginated(datasetio_settings): - datasetio_impl = datasetio_settings["datasetio_impl"] - datasets_impl = datasetio_settings["datasets_impl"] - await register_dataset(datasets_impl) - - response = await datasetio_impl.get_rows_paginated( - dataset_id="test_dataset", - rows_in_page=3, +async def register_memory_bank(banks_impl: MemoryBanks): + bank = VectorMemoryBankDef( + identifier="test_bank", + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=512, + overlap_size_in_tokens=64, ) - assert isinstance(response.rows, list) - assert len(response.rows) == 3 - assert response.next_page_token == "3" + await banks_impl.register_memory_bank(bank) - # iterate over all rows - response = await datasetio_impl.get_rows_paginated( - dataset_id="test_dataset", - rows_in_page=2, - page_token=response.next_page_token, - ) - assert isinstance(response.rows, list) - assert len(response.rows) == 2 - assert response.next_page_token == "5" +class TestMemory: + @pytest.mark.asyncio + async def test_banks_list(self, memory_stack): + # NOTE: this needs you to ensure that you are starting from a clean state + # but so far we don't have an unregister API unfortunately, so be careful + _, banks_impl = memory_stack + response = await banks_impl.list_memory_banks() + assert isinstance(response, list) + assert len(response) == 0 + + @pytest.mark.asyncio + async def test_banks_register(self, memory_stack): + # NOTE: this needs you to ensure that you are starting from a clean state + # but so far we don't have an unregister API unfortunately, so be careful + _, banks_impl = memory_stack + bank = VectorMemoryBankDef( + identifier="test_bank_no_provider", + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=512, + overlap_size_in_tokens=64, + ) + + await banks_impl.register_memory_bank(bank) + response = await banks_impl.list_memory_banks() + assert isinstance(response, list) + assert len(response) == 1 + + # register same memory bank with same id again will fail + await banks_impl.register_memory_bank(bank) + response = await banks_impl.list_memory_banks() + assert isinstance(response, list) + assert len(response) == 1 diff --git a/llama_stack/providers/tests/datasetio/test_datasetio_old.py b/llama_stack/providers/tests/datasetio/test_datasetio_old.py new file mode 100644 index 000000000..866b1e270 --- /dev/null +++ b/llama_stack/providers/tests/datasetio/test_datasetio_old.py @@ -0,0 +1,148 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +import os + +import pytest +import pytest_asyncio + +from llama_stack.apis.common.type_system import * # noqa: F403 +from llama_stack.apis.datasetio import * # noqa: F403 +from llama_stack.distribution.datatypes import * # noqa: F403 +import base64 +import mimetypes +from pathlib import Path + +from llama_stack.providers.tests.resolver import resolve_impls_for_test + +# How to run this test: +# +# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky +# since it depends on the provider you are testing. On top of that you need +# `pytest` and `pytest-asyncio` installed. +# +# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing. +# +# 3. Run: +# +# ```bash +# PROVIDER_ID= \ +# PROVIDER_CONFIG=provider_config.yaml \ +# pytest -s llama_stack/providers/tests/datasetio/test_datasetio.py \ +# --tb=short --disable-warnings +# ``` + + +@pytest_asyncio.fixture(scope="session") +async def datasetio_settings(): + impls = await resolve_impls_for_test( + Api.datasetio, + ) + return { + "datasetio_impl": impls[Api.datasetio], + "datasets_impl": impls[Api.datasets], + } + + +def data_url_from_file(file_path: str) -> str: + if not os.path.exists(file_path): + raise FileNotFoundError(f"File not found: {file_path}") + + with open(file_path, "rb") as file: + file_content = file.read() + + base64_content = base64.b64encode(file_content).decode("utf-8") + mime_type, _ = mimetypes.guess_type(file_path) + + data_url = f"data:{mime_type};base64,{base64_content}" + + return data_url + + +async def register_dataset( + datasets_impl: Datasets, for_generation=False, dataset_id="test_dataset" +): + test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv" + test_url = data_url_from_file(str(test_file)) + + if for_generation: + dataset_schema = { + "expected_answer": StringType(), + "input_query": StringType(), + "chat_completion_input": ChatCompletionInputType(), + } + else: + dataset_schema = { + "expected_answer": StringType(), + "input_query": StringType(), + "generated_answer": StringType(), + } + + dataset = DatasetDefWithProvider( + identifier=dataset_id, + provider_id=os.environ.get("DATASETIO_PROVIDER_ID", None) + or os.environ["PROVIDER_ID"], + url=URL( + uri=test_url, + ), + dataset_schema=dataset_schema, + ) + await datasets_impl.register_dataset(dataset) + + +@pytest.mark.asyncio +async def test_datasets_list(datasetio_settings): + # NOTE: this needs you to ensure that you are starting from a clean state + # but so far we don't have an unregister API unfortunately, so be careful + datasets_impl = datasetio_settings["datasets_impl"] + response = await datasets_impl.list_datasets() + assert isinstance(response, list) + assert len(response) == 0 + + +@pytest.mark.asyncio +async def test_datasets_register(datasetio_settings): + # NOTE: this needs you to ensure that you are starting from a clean state + # but so far we don't have an unregister API unfortunately, so be careful + datasets_impl = datasetio_settings["datasets_impl"] + await register_dataset(datasets_impl) + + response = await datasets_impl.list_datasets() + assert isinstance(response, list) + assert len(response) == 1 + + # register same dataset with same id again will fail + await register_dataset(datasets_impl) + response = await datasets_impl.list_datasets() + assert isinstance(response, list) + assert len(response) == 1 + assert response[0].identifier == "test_dataset" + + +@pytest.mark.asyncio +async def test_get_rows_paginated(datasetio_settings): + datasetio_impl = datasetio_settings["datasetio_impl"] + datasets_impl = datasetio_settings["datasets_impl"] + await register_dataset(datasets_impl) + + response = await datasetio_impl.get_rows_paginated( + dataset_id="test_dataset", + rows_in_page=3, + ) + + assert isinstance(response.rows, list) + assert len(response.rows) == 3 + assert response.next_page_token == "3" + + # iterate over all rows + response = await datasetio_impl.get_rows_paginated( + dataset_id="test_dataset", + rows_in_page=2, + page_token=response.next_page_token, + ) + + assert isinstance(response.rows, list) + assert len(response.rows) == 2 + assert response.next_page_token == "5"