diff --git a/docs/source/distributions/self_hosted_distro/fireworks.md b/docs/source/distributions/self_hosted_distro/fireworks.md
index 1fcd6f7af..9592a18fe 100644
--- a/docs/source/distributions/self_hosted_distro/fireworks.md
+++ b/docs/source/distributions/self_hosted_distro/fireworks.md
@@ -22,7 +22,7 @@ The `llamastack/distribution-fireworks` distribution consists of the following p
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `remote::wolfram-alpha`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
 | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
diff --git a/docs/source/distributions/self_hosted_distro/ollama.md b/docs/source/distributions/self_hosted_distro/ollama.md
index 8f23cef43..fb3f9164a 100644
--- a/docs/source/distributions/self_hosted_distro/ollama.md
+++ b/docs/source/distributions/self_hosted_distro/ollama.md
@@ -22,7 +22,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
 | vector_io | `inline::sqlite-vec`, `remote::chromadb`, `remote::pgvector` |
diff --git a/docs/source/distributions/self_hosted_distro/remote-vllm.md b/docs/source/distributions/self_hosted_distro/remote-vllm.md
index 01f38807b..b7e155385 100644
--- a/docs/source/distributions/self_hosted_distro/remote-vllm.md
+++ b/docs/source/distributions/self_hosted_distro/remote-vllm.md
@@ -21,7 +21,7 @@ The `llamastack/distribution-remote-vllm` distribution consists of the following
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
 | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
diff --git a/docs/source/distributions/self_hosted_distro/together.md b/docs/source/distributions/self_hosted_distro/together.md
index f361e93c7..fa02199b0 100644
--- a/docs/source/distributions/self_hosted_distro/together.md
+++ b/docs/source/distributions/self_hosted_distro/together.md
@@ -22,7 +22,7 @@ The `llamastack/distribution-together` distribution consists of the following pr
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
 | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index 80e9ecb7c..73f9c9672 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -366,7 +366,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
                 provider_id = list(self.impls_by_provider_id.keys())[0]
             else:
                 raise ValueError(
-                    "No provider specified and multiple providers available. Please specify a provider_id."
+                    f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
                 )
         if metadata is None:
             metadata = {}
diff --git a/llama_stack/providers/tests/env.py b/llama_stack/env.py
similarity index 100%
rename from llama_stack/providers/tests/env.py
rename to llama_stack/env.py
diff --git a/llama_stack/providers/tests/datasetio/conftest.py b/llama_stack/providers/tests/datasetio/conftest.py
deleted file mode 100644
index 740eddb33..000000000
--- a/llama_stack/providers/tests/datasetio/conftest.py
+++ /dev/null
@@ -1,29 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import pytest
-
-from .fixtures import DATASETIO_FIXTURES
-
-
-def pytest_configure(config):
-    for fixture_name in DATASETIO_FIXTURES:
-        config.addinivalue_line(
-            "markers",
-            f"{fixture_name}: marks tests as {fixture_name} specific",
-        )
-
-
-def pytest_generate_tests(metafunc):
-    if "datasetio_stack" in metafunc.fixturenames:
-        metafunc.parametrize(
-            "datasetio_stack",
-            [
-                pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
-                for fixture_name in DATASETIO_FIXTURES
-            ],
-            indirect=True,
-        )
diff --git a/llama_stack/providers/tests/datasetio/fixtures.py b/llama_stack/providers/tests/datasetio/fixtures.py
deleted file mode 100644
index 27aedb645..000000000
--- a/llama_stack/providers/tests/datasetio/fixtures.py
+++ /dev/null
@@ -1,61 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
- -import pytest -import pytest_asyncio - -from llama_stack.distribution.datatypes import Api, Provider -from llama_stack.providers.tests.resolver import construct_stack_for_test - -from ..conftest import ProviderFixture, remote_stack_fixture - - -@pytest.fixture(scope="session") -def datasetio_remote() -> ProviderFixture: - return remote_stack_fixture() - - -@pytest.fixture(scope="session") -def datasetio_localfs() -> ProviderFixture: - return ProviderFixture( - providers=[ - Provider( - provider_id="localfs", - provider_type="inline::localfs", - config={}, - ) - ], - ) - - -@pytest.fixture(scope="session") -def datasetio_huggingface() -> ProviderFixture: - return ProviderFixture( - providers=[ - Provider( - provider_id="huggingface", - provider_type="remote::huggingface", - config={}, - ) - ], - ) - - -DATASETIO_FIXTURES = ["localfs", "remote", "huggingface"] - - -@pytest_asyncio.fixture(scope="session") -async def datasetio_stack(request): - fixture_name = request.param - fixture = request.getfixturevalue(f"datasetio_{fixture_name}") - - test_stack = await construct_stack_for_test( - [Api.datasetio], - {"datasetio": fixture.providers}, - fixture.provider_data, - ) - - return test_stack.impls[Api.datasetio], test_stack.impls[Api.datasets] diff --git a/llama_stack/providers/tests/datasetio/test_datasetio.py b/llama_stack/providers/tests/datasetio/test_datasetio.py deleted file mode 100644 index fd76bafe0..000000000 --- a/llama_stack/providers/tests/datasetio/test_datasetio.py +++ /dev/null @@ -1,134 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import base64 -import mimetypes -import os -from pathlib import Path - -import pytest - -from llama_stack.apis.common.content_types import URL -from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType -from llama_stack.apis.datasets import Datasets - -# How to run this test: -# -# pytest llama_stack/providers/tests/datasetio/test_datasetio.py -# -m "meta_reference" -# -v -s --tb=short --disable-warnings - - -def data_url_from_file(file_path: str) -> str: - if not os.path.exists(file_path): - raise FileNotFoundError(f"File not found: {file_path}") - - with open(file_path, "rb") as file: - file_content = file.read() - - base64_content = base64.b64encode(file_content).decode("utf-8") - mime_type, _ = mimetypes.guess_type(file_path) - - data_url = f"data:{mime_type};base64,{base64_content}" - - return data_url - - -async def register_dataset( - datasets_impl: Datasets, - for_generation=False, - for_rag=False, - dataset_id="test_dataset", -): - if for_rag: - test_file = Path(os.path.abspath(__file__)).parent / "test_rag_dataset.csv" - else: - test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv" - test_url = data_url_from_file(str(test_file)) - - if for_generation: - dataset_schema = { - "expected_answer": StringType(), - "input_query": StringType(), - "chat_completion_input": ChatCompletionInputType(), - } - elif for_rag: - dataset_schema = { - "expected_answer": StringType(), - "input_query": StringType(), - "generated_answer": StringType(), - "context": StringType(), - } - else: - dataset_schema = { - "expected_answer": StringType(), - "input_query": StringType(), - "generated_answer": StringType(), - } - - await datasets_impl.register_dataset( - dataset_id=dataset_id, - dataset_schema=dataset_schema, - url=URL(uri=test_url), - ) - - 
-class TestDatasetIO: - @pytest.mark.asyncio - async def test_datasets_list(self, datasetio_stack): - # NOTE: this needs you to ensure that you are starting from a clean state - # but so far we don't have an unregister API unfortunately, so be careful - _, datasets_impl = datasetio_stack - response = await datasets_impl.list_datasets() - assert isinstance(response, list) - assert len(response) == 0 - - @pytest.mark.asyncio - async def test_register_dataset(self, datasetio_stack): - _, datasets_impl = datasetio_stack - await register_dataset(datasets_impl) - response = await datasets_impl.list_datasets() - assert isinstance(response, list) - assert len(response) == 1 - assert response[0].identifier == "test_dataset" - - with pytest.raises(ValueError): - # unregister a dataset that does not exist - await datasets_impl.unregister_dataset("test_dataset2") - - await datasets_impl.unregister_dataset("test_dataset") - response = await datasets_impl.list_datasets() - assert isinstance(response, list) - assert len(response) == 0 - - with pytest.raises(ValueError): - await datasets_impl.unregister_dataset("test_dataset") - - @pytest.mark.asyncio - async def test_get_rows_paginated(self, datasetio_stack): - datasetio_impl, datasets_impl = datasetio_stack - await register_dataset(datasets_impl) - response = await datasetio_impl.get_rows_paginated( - dataset_id="test_dataset", - rows_in_page=3, - ) - assert isinstance(response.rows, list) - assert len(response.rows) == 3 - assert response.next_page_token == "3" - - provider = datasetio_impl.routing_table.get_provider_impl("test_dataset") - if provider.__provider_spec__.provider_type == "remote": - pytest.skip("remote provider doesn't support get_rows_paginated") - - # iterate over all rows - response = await datasetio_impl.get_rows_paginated( - dataset_id="test_dataset", - rows_in_page=2, - page_token=response.next_page_token, - ) - assert isinstance(response.rows, list) - assert len(response.rows) == 2 - assert response.next_page_token == "5" diff --git a/llama_stack/providers/tests/eval/conftest.py b/llama_stack/providers/tests/eval/conftest.py deleted file mode 100644 index c1da6ba42..000000000 --- a/llama_stack/providers/tests/eval/conftest.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -import pytest - -from ..agents.fixtures import AGENTS_FIXTURES -from ..conftest import get_provider_fixture_overrides -from ..datasetio.fixtures import DATASETIO_FIXTURES -from ..inference.fixtures import INFERENCE_FIXTURES -from ..safety.fixtures import SAFETY_FIXTURES -from ..scoring.fixtures import SCORING_FIXTURES -from ..tools.fixtures import TOOL_RUNTIME_FIXTURES -from ..vector_io.fixtures import VECTOR_IO_FIXTURES -from .fixtures import EVAL_FIXTURES - -DEFAULT_PROVIDER_COMBINATIONS = [ - pytest.param( - { - "eval": "meta_reference", - "scoring": "basic", - "datasetio": "localfs", - "inference": "fireworks", - "agents": "meta_reference", - "safety": "llama_guard", - "vector_io": "faiss", - "tool_runtime": "memory_and_search", - }, - id="meta_reference_eval_fireworks_inference", - marks=pytest.mark.meta_reference_eval_fireworks_inference, - ), - pytest.param( - { - "eval": "meta_reference", - "scoring": "basic", - "datasetio": "localfs", - "inference": "together", - "agents": "meta_reference", - "safety": "llama_guard", - "vector_io": "faiss", - "tool_runtime": "memory_and_search", - }, - id="meta_reference_eval_together_inference", - marks=pytest.mark.meta_reference_eval_together_inference, - ), - pytest.param( - { - "eval": "meta_reference", - "scoring": "basic", - "datasetio": "huggingface", - "inference": "together", - "agents": "meta_reference", - "safety": "llama_guard", - "vector_io": "faiss", - "tool_runtime": "memory_and_search", - }, - id="meta_reference_eval_together_inference_huggingface_datasetio", - marks=pytest.mark.meta_reference_eval_together_inference_huggingface_datasetio, - ), -] - - -def pytest_configure(config): - for fixture_name in [ - "meta_reference_eval_fireworks_inference", - "meta_reference_eval_together_inference", - "meta_reference_eval_together_inference_huggingface_datasetio", - ]: - config.addinivalue_line( - "markers", - f"{fixture_name}: marks tests as {fixture_name} specific", - ) - - -def pytest_generate_tests(metafunc): - if "eval_stack" in metafunc.fixturenames: - available_fixtures = { - "eval": EVAL_FIXTURES, - "scoring": SCORING_FIXTURES, - "datasetio": DATASETIO_FIXTURES, - "inference": INFERENCE_FIXTURES, - "agents": AGENTS_FIXTURES, - "safety": SAFETY_FIXTURES, - "vector_io": VECTOR_IO_FIXTURES, - "tool_runtime": TOOL_RUNTIME_FIXTURES, - } - combinations = ( - get_provider_fixture_overrides(metafunc.config, available_fixtures) or DEFAULT_PROVIDER_COMBINATIONS - ) - metafunc.parametrize("eval_stack", combinations, indirect=True) diff --git a/llama_stack/providers/tests/eval/fixtures.py b/llama_stack/providers/tests/eval/fixtures.py deleted file mode 100644 index c6d15bbf5..000000000 --- a/llama_stack/providers/tests/eval/fixtures.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -import pytest -import pytest_asyncio - -from llama_stack.distribution.datatypes import Api, ModelInput, Provider -from llama_stack.providers.tests.resolver import construct_stack_for_test - -from ..conftest import ProviderFixture, remote_stack_fixture - - -@pytest.fixture(scope="session") -def eval_remote() -> ProviderFixture: - return remote_stack_fixture() - - -@pytest.fixture(scope="session") -def eval_meta_reference() -> ProviderFixture: - return ProviderFixture( - providers=[ - Provider( - provider_id="meta-reference", - provider_type="inline::meta-reference", - config={}, - ) - ], - ) - - -EVAL_FIXTURES = ["meta_reference", "remote"] - - -@pytest_asyncio.fixture(scope="session") -async def eval_stack( - request, - inference_model, - judge_model, - tool_group_input_memory, - tool_group_input_tavily_search, -): - fixture_dict = request.param - - providers = {} - provider_data = {} - for key in [ - "datasetio", - "eval", - "scoring", - "inference", - "agents", - "safety", - "vector_io", - "tool_runtime", - ]: - fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}") - providers[key] = fixture.providers - if fixture.provider_data: - provider_data.update(fixture.provider_data) - - test_stack = await construct_stack_for_test( - [ - Api.eval, - Api.datasetio, - Api.inference, - Api.scoring, - Api.agents, - Api.safety, - Api.vector_io, - Api.tool_runtime, - ], - providers, - provider_data, - models=[ - ModelInput(model_id=model) - for model in [ - inference_model, - judge_model, - ] - ], - tool_groups=[tool_group_input_memory, tool_group_input_tavily_search], - ) - - return test_stack.impls diff --git a/llama_stack/providers/tests/post_training/conftest.py b/llama_stack/providers/tests/post_training/conftest.py deleted file mode 100644 index b6d95444b..000000000 --- a/llama_stack/providers/tests/post_training/conftest.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import pytest - -from ..conftest import get_provider_fixture_overrides -from ..datasetio.fixtures import DATASETIO_FIXTURES -from .fixtures import POST_TRAINING_FIXTURES - -DEFAULT_PROVIDER_COMBINATIONS = [ - pytest.param( - { - "post_training": "torchtune", - "datasetio": "huggingface", - }, - id="torchtune_post_training_huggingface_datasetio", - marks=pytest.mark.torchtune_post_training_huggingface_datasetio, - ), -] - - -def pytest_configure(config): - combined_fixtures = "torchtune_post_training_huggingface_datasetio" - config.addinivalue_line( - "markers", - f"{combined_fixtures}: marks tests as {combined_fixtures} specific", - ) - - -def pytest_generate_tests(metafunc): - if "post_training_stack" in metafunc.fixturenames: - available_fixtures = { - "eval": POST_TRAINING_FIXTURES, - "datasetio": DATASETIO_FIXTURES, - } - combinations = ( - get_provider_fixture_overrides(metafunc.config, available_fixtures) or DEFAULT_PROVIDER_COMBINATIONS - ) - metafunc.parametrize("post_training_stack", combinations, indirect=True) diff --git a/llama_stack/providers/tests/post_training/fixtures.py b/llama_stack/providers/tests/post_training/fixtures.py deleted file mode 100644 index 7c3ff3ddb..000000000 --- a/llama_stack/providers/tests/post_training/fixtures.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import pytest -import pytest_asyncio - -from llama_stack.apis.common.content_types import URL -from llama_stack.apis.common.type_system import StringType -from llama_stack.apis.datasets import DatasetInput -from llama_stack.apis.models import ModelInput -from llama_stack.distribution.datatypes import Api, Provider -from llama_stack.providers.tests.resolver import construct_stack_for_test - -from ..conftest import ProviderFixture - - -@pytest.fixture(scope="session") -def post_training_torchtune() -> ProviderFixture: - return ProviderFixture( - providers=[ - Provider( - provider_id="torchtune", - provider_type="inline::torchtune", - config={}, - ) - ], - ) - - -POST_TRAINING_FIXTURES = ["torchtune"] - - -@pytest_asyncio.fixture(scope="session") -async def post_training_stack(request): - fixture_dict = request.param - - providers = {} - provider_data = {} - for key in ["post_training", "datasetio"]: - fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}") - providers[key] = fixture.providers - if fixture.provider_data: - provider_data.update(fixture.provider_data) - - test_stack = await construct_stack_for_test( - [Api.post_training, Api.datasetio], - providers, - provider_data, - models=[ModelInput(model_id="meta-llama/Llama-3.2-3B-Instruct")], - datasets=[ - DatasetInput( - dataset_id="alpaca", - provider_id="huggingface", - url=URL(uri="https://huggingface.co/datasets/tatsu-lab/alpaca"), - metadata={ - "path": "tatsu-lab/alpaca", - "split": "train", - }, - dataset_schema={ - "instruction": StringType(), - "input": StringType(), - "output": StringType(), - "text": StringType(), - }, - ), - ], - ) - - return test_stack.impls[Api.post_training] diff --git a/llama_stack/providers/tests/scoring/conftest.py b/llama_stack/providers/tests/scoring/conftest.py deleted file mode 100644 index 9278d3c2d..000000000 --- a/llama_stack/providers/tests/scoring/conftest.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -import pytest - -from ..conftest import get_provider_fixture_overrides -from ..datasetio.fixtures import DATASETIO_FIXTURES -from ..inference.fixtures import INFERENCE_FIXTURES -from .fixtures import SCORING_FIXTURES - -DEFAULT_PROVIDER_COMBINATIONS = [ - pytest.param( - { - "scoring": "basic", - "datasetio": "localfs", - "inference": "together", - }, - id="basic_scoring_together_inference", - marks=pytest.mark.basic_scoring_together_inference, - ), - pytest.param( - { - "scoring": "braintrust", - "datasetio": "localfs", - "inference": "together", - }, - id="braintrust_scoring_together_inference", - marks=pytest.mark.braintrust_scoring_together_inference, - ), - pytest.param( - { - "scoring": "llm_as_judge", - "datasetio": "localfs", - "inference": "together", - }, - id="llm_as_judge_scoring_together_inference", - marks=pytest.mark.llm_as_judge_scoring_together_inference, - ), -] - - -def pytest_configure(config): - for fixture_name in [ - "basic_scoring_together_inference", - "braintrust_scoring_together_inference", - "llm_as_judge_scoring_together_inference", - ]: - config.addinivalue_line( - "markers", - f"{fixture_name}: marks tests as {fixture_name} specific", - ) - - -def pytest_generate_tests(metafunc): - judge_model = metafunc.config.getoption("--judge-model") - if "judge_model" in metafunc.fixturenames: - metafunc.parametrize( - "judge_model", - [pytest.param(judge_model, id="")], - indirect=True, - ) - - if "scoring_stack" in metafunc.fixturenames: - available_fixtures = { - "scoring": SCORING_FIXTURES, - "datasetio": DATASETIO_FIXTURES, - "inference": INFERENCE_FIXTURES, - } - combinations = ( - get_provider_fixture_overrides(metafunc.config, available_fixtures) or DEFAULT_PROVIDER_COMBINATIONS - ) - metafunc.parametrize("scoring_stack", combinations, indirect=True) diff --git a/llama_stack/providers/tests/scoring/fixtures.py b/llama_stack/providers/tests/scoring/fixtures.py deleted file mode 100644 index 09f31cbc2..000000000 --- a/llama_stack/providers/tests/scoring/fixtures.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -import pytest -import pytest_asyncio - -from llama_stack.apis.models import ModelInput -from llama_stack.distribution.datatypes import Api, Provider -from llama_stack.providers.inline.scoring.braintrust import BraintrustScoringConfig -from llama_stack.providers.tests.resolver import construct_stack_for_test - -from ..conftest import ProviderFixture, remote_stack_fixture -from ..env import get_env_or_fail - - -@pytest.fixture(scope="session") -def scoring_remote() -> ProviderFixture: - return remote_stack_fixture() - - -@pytest.fixture(scope="session") -def judge_model(request): - if hasattr(request, "param"): - return request.param - return request.config.getoption("--judge-model", None) - - -@pytest.fixture(scope="session") -def scoring_basic() -> ProviderFixture: - return ProviderFixture( - providers=[ - Provider( - provider_id="basic", - provider_type="inline::basic", - config={}, - ) - ], - ) - - -@pytest.fixture(scope="session") -def scoring_braintrust() -> ProviderFixture: - return ProviderFixture( - providers=[ - Provider( - provider_id="braintrust", - provider_type="inline::braintrust", - config=BraintrustScoringConfig( - openai_api_key=get_env_or_fail("OPENAI_API_KEY"), - ).model_dump(), - ) - ], - ) - - -@pytest.fixture(scope="session") -def scoring_llm_as_judge() -> ProviderFixture: - return ProviderFixture( - providers=[ - Provider( - provider_id="llm-as-judge", - provider_type="inline::llm-as-judge", - config={}, - ) - ], - ) - - -SCORING_FIXTURES = ["basic", "remote", "braintrust", "llm_as_judge"] - - -@pytest_asyncio.fixture(scope="session") -async def scoring_stack(request, inference_model, judge_model): - fixture_dict = request.param - - providers = {} - provider_data = {} - for key in ["datasetio", "scoring", "inference"]: - fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}") - providers[key] = fixture.providers - if fixture.provider_data: - provider_data.update(fixture.provider_data) - - test_stack = await construct_stack_for_test( - [Api.scoring, Api.datasetio, Api.inference], - providers, - provider_data, - models=[ - ModelInput(model_id=model) - for model in [ - inference_model, - judge_model, - ] - ], - ) - - return test_stack.impls diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py deleted file mode 100644 index d80b105f4..000000000 --- a/llama_stack/providers/tests/scoring/test_scoring.py +++ /dev/null @@ -1,213 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - - -import pytest - -from llama_stack.apis.scoring_functions import ( - AggregationFunctionType, - BasicScoringFnParams, - LLMAsJudgeScoringFnParams, - RegexParserScoringFnParams, -) -from llama_stack.distribution.datatypes import Api -from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset - -# How to run this test: -# -# pytest llama_stack/providers/tests/scoring/test_scoring.py -# -m "meta_reference" -# -v -s --tb=short --disable-warnings - - -@pytest.fixture -def sample_judge_prompt_template(): - return "Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9."
- - -class TestScoring: - @pytest.mark.asyncio - async def test_scoring_functions_list(self, scoring_stack): - # NOTE: this needs you to ensure that you are starting from a clean state - # but so far we don't have an unregister API unfortunately, so be careful - scoring_functions_impl = scoring_stack[Api.scoring_functions] - response = await scoring_functions_impl.list_scoring_functions() - assert isinstance(response, list) - assert len(response) > 0 - - @pytest.mark.asyncio - async def test_scoring_score(self, scoring_stack): - ( - scoring_impl, - scoring_functions_impl, - datasetio_impl, - datasets_impl, - ) = ( - scoring_stack[Api.scoring], - scoring_stack[Api.scoring_functions], - scoring_stack[Api.datasetio], - scoring_stack[Api.datasets], - ) - scoring_fns_list = await scoring_functions_impl.list_scoring_functions() - provider_id = scoring_fns_list[0].provider_id - if provider_id == "llm-as-judge": - pytest.skip(f"{provider_id} provider does not support scoring without params") - - await register_dataset(datasets_impl, for_rag=True) - response = await datasets_impl.list_datasets() - assert len(response) == 1 - - # scoring individual rows - rows = await datasetio_impl.get_rows_paginated( - dataset_id="test_dataset", - rows_in_page=3, - ) - assert len(rows.rows) == 3 - - scoring_fns_list = await scoring_functions_impl.list_scoring_functions() - scoring_functions = { - scoring_fns_list[0].identifier: None, - } - - response = await scoring_impl.score( - input_rows=rows.rows, - scoring_functions=scoring_functions, - ) - assert len(response.results) == len(scoring_functions) - for x in scoring_functions: - assert x in response.results - assert len(response.results[x].score_rows) == len(rows.rows) - - # score batch - response = await scoring_impl.score_batch( - dataset_id="test_dataset", - scoring_functions=scoring_functions, - ) - assert len(response.results) == len(scoring_functions) - for x in scoring_functions: - assert x in response.results - assert len(response.results[x].score_rows) == 5 - - @pytest.mark.asyncio - async def test_scoring_score_with_params_llm_as_judge( - self, scoring_stack, sample_judge_prompt_template, judge_model - ): - ( - scoring_impl, - scoring_functions_impl, - datasetio_impl, - datasets_impl, - ) = ( - scoring_stack[Api.scoring], - scoring_stack[Api.scoring_functions], - scoring_stack[Api.datasetio], - scoring_stack[Api.datasets], - ) - await register_dataset(datasets_impl, for_rag=True) - response = await datasets_impl.list_datasets() - assert len(response) == 1 - - scoring_fns_list = await scoring_functions_impl.list_scoring_functions() - provider_id = scoring_fns_list[0].provider_id - if provider_id == "braintrust" or provider_id == "basic": - pytest.skip(f"{provider_id} provider does not support scoring with params") - - # scoring individual rows - rows = await datasetio_impl.get_rows_paginated( - dataset_id="test_dataset", - rows_in_page=3, - ) - assert len(rows.rows) == 3 - - scoring_functions = { - "llm-as-judge::base": LLMAsJudgeScoringFnParams( - judge_model=judge_model, - prompt_template=sample_judge_prompt_template, - judge_score_regexes=[r"Score: (\d+)"], - aggregation_functions=[AggregationFunctionType.categorical_count], - ) - } - - response = await scoring_impl.score( - input_rows=rows.rows, - scoring_functions=scoring_functions, - ) - assert len(response.results) == len(scoring_functions) - for x in scoring_functions: - assert x in response.results - assert len(response.results[x].score_rows) == len(rows.rows) - - # score batch - response = 
await scoring_impl.score_batch( - dataset_id="test_dataset", - scoring_functions=scoring_functions, - ) - assert len(response.results) == len(scoring_functions) - for x in scoring_functions: - assert x in response.results - assert len(response.results[x].score_rows) == 5 - - @pytest.mark.asyncio - async def test_scoring_score_with_aggregation_functions( - self, scoring_stack, sample_judge_prompt_template, judge_model - ): - ( - scoring_impl, - scoring_functions_impl, - datasetio_impl, - datasets_impl, - ) = ( - scoring_stack[Api.scoring], - scoring_stack[Api.scoring_functions], - scoring_stack[Api.datasetio], - scoring_stack[Api.datasets], - ) - await register_dataset(datasets_impl, for_rag=True) - rows = await datasetio_impl.get_rows_paginated( - dataset_id="test_dataset", - rows_in_page=3, - ) - assert len(rows.rows) == 3 - - scoring_fns_list = await scoring_functions_impl.list_scoring_functions() - scoring_functions = {} - aggr_fns = [ - AggregationFunctionType.accuracy, - AggregationFunctionType.median, - AggregationFunctionType.categorical_count, - AggregationFunctionType.average, - ] - for x in scoring_fns_list: - if x.provider_id == "llm-as-judge": - aggr_fns = [AggregationFunctionType.categorical_count] - scoring_functions[x.identifier] = LLMAsJudgeScoringFnParams( - judge_model=judge_model, - prompt_template=sample_judge_prompt_template, - judge_score_regexes=[r"Score: (\d+)"], - aggregation_functions=aggr_fns, - ) - elif x.provider_id == "basic" or x.provider_id == "braintrust": - if "regex_parser" in x.identifier: - scoring_functions[x.identifier] = RegexParserScoringFnParams( - aggregation_functions=aggr_fns, - ) - else: - scoring_functions[x.identifier] = BasicScoringFnParams( - aggregation_functions=aggr_fns, - ) - else: - scoring_functions[x.identifier] = None - - response = await scoring_impl.score( - input_rows=rows.rows, - scoring_functions=scoring_functions, - ) - - assert len(response.results) == len(scoring_functions) - for x in scoring_functions: - assert x in response.results - assert len(response.results[x].score_rows) == len(rows.rows) - assert len(response.results[x].aggregated_results) == len(aggr_fns) diff --git a/llama_stack/providers/tests/tools/__init__.py b/llama_stack/providers/tests/tools/__init__.py deleted file mode 100644 index 756f351d8..000000000 --- a/llama_stack/providers/tests/tools/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. diff --git a/llama_stack/providers/tests/tools/conftest.py b/llama_stack/providers/tests/tools/conftest.py deleted file mode 100644 index 253ae88f0..000000000 --- a/llama_stack/providers/tests/tools/conftest.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -import pytest - -from ..conftest import get_provider_fixture_overrides -from ..inference.fixtures import INFERENCE_FIXTURES -from ..safety.fixtures import SAFETY_FIXTURES -from ..vector_io.fixtures import VECTOR_IO_FIXTURES -from .fixtures import TOOL_RUNTIME_FIXTURES - -DEFAULT_PROVIDER_COMBINATIONS = [ - pytest.param( - { - "inference": "together", - "safety": "llama_guard", - "vector_io": "faiss", - "tool_runtime": "memory_and_search", - }, - id="together", - marks=pytest.mark.together, - ), -] - - -def pytest_configure(config): - for mark in ["together"]: - config.addinivalue_line( - "markers", - f"{mark}: marks tests as {mark} specific", - ) - - -def pytest_generate_tests(metafunc): - if "tools_stack" in metafunc.fixturenames: - available_fixtures = { - "inference": INFERENCE_FIXTURES, - "safety": SAFETY_FIXTURES, - "vector_io": VECTOR_IO_FIXTURES, - "tool_runtime": TOOL_RUNTIME_FIXTURES, - } - combinations = ( - get_provider_fixture_overrides(metafunc.config, available_fixtures) or DEFAULT_PROVIDER_COMBINATIONS - ) - metafunc.parametrize("tools_stack", combinations, indirect=True) diff --git a/llama_stack/providers/tests/tools/fixtures.py b/llama_stack/providers/tests/tools/fixtures.py deleted file mode 100644 index ddf8e9af2..000000000 --- a/llama_stack/providers/tests/tools/fixtures.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import os - -import pytest -import pytest_asyncio - -from llama_stack.apis.models import ModelInput, ModelType -from llama_stack.apis.tools import ToolGroupInput -from llama_stack.distribution.datatypes import Api, Provider -from llama_stack.providers.tests.resolver import construct_stack_for_test - -from ..conftest import ProviderFixture - - -@pytest.fixture(scope="session") -def tool_runtime_memory_and_search() -> ProviderFixture: - return ProviderFixture( - providers=[ - Provider( - provider_id="rag-runtime", - provider_type="inline::rag-runtime", - config={}, - ), - Provider( - provider_id="tavily-search", - provider_type="remote::tavily-search", - config={ - "api_key": os.environ["TAVILY_SEARCH_API_KEY"], - }, - ), - Provider( - provider_id="wolfram-alpha", - provider_type="remote::wolfram-alpha", - config={ - "api_key": os.environ["WOLFRAM_ALPHA_API_KEY"], - }, - ), - ], - ) - - -@pytest.fixture(scope="session") -def tool_group_input_memory() -> ToolGroupInput: - return ToolGroupInput( - toolgroup_id="builtin::rag", - provider_id="rag-runtime", - ) - - -@pytest.fixture(scope="session") -def tool_group_input_tavily_search() -> ToolGroupInput: - return ToolGroupInput( - toolgroup_id="builtin::web_search", - provider_id="tavily-search", - ) - - -@pytest.fixture(scope="session") -def tool_group_input_wolfram_alpha() -> ToolGroupInput: - return ToolGroupInput( - toolgroup_id="builtin::wolfram_alpha", - provider_id="wolfram-alpha", - ) - - -TOOL_RUNTIME_FIXTURES = ["memory_and_search"] - - -@pytest_asyncio.fixture(scope="session") -async def tools_stack( - request, - inference_model, - tool_group_input_memory, - tool_group_input_tavily_search, - tool_group_input_wolfram_alpha, -): - fixture_dict = request.param - - providers = {} - provider_data = {} - for key in ["inference", "vector_io", "tool_runtime"]: - fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}") - providers[key] = fixture.providers - if key == "inference": - providers[key].append( - 
Provider( - provider_id="tools_memory_provider", - provider_type="inline::sentence-transformers", - config={}, - ) - ) - if fixture.provider_data: - provider_data.update(fixture.provider_data) - inference_models = inference_model if isinstance(inference_model, list) else [inference_model] - models = [ - ModelInput( - model_id=model, - model_type=ModelType.llm, - provider_id=providers["inference"][0].provider_id, - ) - for model in inference_models - ] - models.append( - ModelInput( - model_id="all-MiniLM-L6-v2", - model_type=ModelType.embedding, - provider_id="tools_memory_provider", - metadata={"embedding_dimension": 384}, - ) - ) - - test_stack = await construct_stack_for_test( - [ - Api.tool_groups, - Api.inference, - Api.vector_io, - Api.tool_runtime, - ], - providers, - provider_data, - models=models, - tool_groups=[ - tool_group_input_tavily_search, - tool_group_input_wolfram_alpha, - tool_group_input_memory, - ], - ) - return test_stack diff --git a/llama_stack/providers/tests/tools/test_tools.py b/llama_stack/providers/tests/tools/test_tools.py deleted file mode 100644 index 8188f3dd7..000000000 --- a/llama_stack/providers/tests/tools/test_tools.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import os - -import pytest - -from llama_stack.apis.tools import RAGDocument, RAGQueryResult, ToolInvocationResult -from llama_stack.providers.datatypes import Api - - -@pytest.fixture -def sample_search_query(): - return "What are the latest developments in quantum computing?" - - -@pytest.fixture -def sample_wolfram_alpha_query(): - return "What is the square root of 16?" 
- - -@pytest.fixture -def sample_documents(): - urls = [ - "memory_optimizations.rst", - "chat.rst", - "llama3.rst", - "qat_finetune.rst", - "lora_finetune.rst", - ] - return [ - RAGDocument( - document_id=f"num-{i}", - content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}", - mime_type="text/plain", - metadata={}, - ) - for i, url in enumerate(urls) - ] - - -class TestTools: - @pytest.mark.asyncio - async def test_web_search_tool(self, tools_stack, sample_search_query): - """Test the web search tool functionality.""" - if "TAVILY_SEARCH_API_KEY" not in os.environ: - pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test") - - tools_impl = tools_stack.impls[Api.tool_runtime] - - # Execute the tool - response = await tools_impl.invoke_tool(tool_name="web_search", kwargs={"query": sample_search_query}) - - # Verify the response - assert isinstance(response, ToolInvocationResult) - assert response.content is not None - assert len(response.content) > 0 - assert isinstance(response.content, str) - - @pytest.mark.asyncio - async def test_wolfram_alpha_tool(self, tools_stack, sample_wolfram_alpha_query): - """Test the wolfram alpha tool functionality.""" - if "WOLFRAM_ALPHA_API_KEY" not in os.environ: - pytest.skip("WOLFRAM_ALPHA_API_KEY not set, skipping test") - - tools_impl = tools_stack.impls[Api.tool_runtime] - - response = await tools_impl.invoke_tool(tool_name="wolfram_alpha", kwargs={"query": sample_wolfram_alpha_query}) - - # Verify the response - assert isinstance(response, ToolInvocationResult) - assert response.content is not None - assert len(response.content) > 0 - assert isinstance(response.content, str) - - @pytest.mark.asyncio - async def test_rag_tool(self, tools_stack, sample_documents): - """Test the memory tool functionality.""" - vector_dbs_impl = tools_stack.impls[Api.vector_dbs] - tools_impl = tools_stack.impls[Api.tool_runtime] - - # Register memory bank - await vector_dbs_impl.register_vector_db( - vector_db_id="test_bank", - embedding_model="all-MiniLM-L6-v2", - embedding_dimension=384, - provider_id="faiss", - ) - - # Insert documents into memory - await tools_impl.rag_tool.insert( - documents=sample_documents, - vector_db_id="test_bank", - chunk_size_in_tokens=512, - ) - - # Execute the memory tool - response = await tools_impl.rag_tool.query( - content="What are the main topics covered in the documentation?", - vector_db_ids=["test_bank"], - ) - - # Verify the response - assert isinstance(response, RAGQueryResult) - assert response.content is not None - assert len(response.content) > 0 diff --git a/llama_stack/templates/fireworks/build.yaml b/llama_stack/templates/fireworks/build.yaml index a9c472c53..3907eba78 100644 --- a/llama_stack/templates/fireworks/build.yaml +++ b/llama_stack/templates/fireworks/build.yaml @@ -27,6 +27,7 @@ distribution_spec: tool_runtime: - remote::brave-search - remote::tavily-search + - remote::wolfram-alpha - inline::code-interpreter - inline::rag-runtime - remote::model-context-protocol diff --git a/llama_stack/templates/fireworks/fireworks.py b/llama_stack/templates/fireworks/fireworks.py index 2baab9d7c..3e6d1ca89 100644 --- a/llama_stack/templates/fireworks/fireworks.py +++ b/llama_stack/templates/fireworks/fireworks.py @@ -35,6 +35,7 @@ def get_distribution_template() -> DistributionTemplate: "tool_runtime": [ "remote::brave-search", "remote::tavily-search", + "remote::wolfram-alpha", "inline::code-interpreter", "inline::rag-runtime", "remote::model-context-protocol", @@ -77,6 +78,10 @@ def 
get_distribution_template() -> DistributionTemplate: toolgroup_id="builtin::websearch", provider_id="tavily-search", ), + ToolGroupInput( + toolgroup_id="builtin::wolfram_alpha", + provider_id="wolfram-alpha", + ), ToolGroupInput( toolgroup_id="builtin::rag", provider_id="rag-runtime", diff --git a/llama_stack/templates/fireworks/run-with-safety.yaml b/llama_stack/templates/fireworks/run-with-safety.yaml index 0fe5f3026..359bf0194 100644 --- a/llama_stack/templates/fireworks/run-with-safety.yaml +++ b/llama_stack/templates/fireworks/run-with-safety.yaml @@ -86,6 +86,9 @@ providers: config: api_key: ${env.TAVILY_SEARCH_API_KEY:} max_results: 3 + - provider_id: wolfram-alpha + provider_type: remote::wolfram-alpha + config: {} - provider_id: code-interpreter provider_type: inline::code-interpreter config: {} @@ -225,6 +228,8 @@ benchmarks: [] tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search +- toolgroup_id: builtin::wolfram_alpha + provider_id: wolfram-alpha - toolgroup_id: builtin::rag provider_id: rag-runtime - toolgroup_id: builtin::code_interpreter diff --git a/llama_stack/templates/fireworks/run.yaml b/llama_stack/templates/fireworks/run.yaml index cbe85c4f7..0ce3a4505 100644 --- a/llama_stack/templates/fireworks/run.yaml +++ b/llama_stack/templates/fireworks/run.yaml @@ -80,6 +80,9 @@ providers: config: api_key: ${env.TAVILY_SEARCH_API_KEY:} max_results: 3 + - provider_id: wolfram-alpha + provider_type: remote::wolfram-alpha + config: {} - provider_id: code-interpreter provider_type: inline::code-interpreter config: {} @@ -214,6 +217,8 @@ benchmarks: [] tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search +- toolgroup_id: builtin::wolfram_alpha + provider_id: wolfram-alpha - toolgroup_id: builtin::rag provider_id: rag-runtime - toolgroup_id: builtin::code_interpreter diff --git a/llama_stack/templates/ollama/build.yaml b/llama_stack/templates/ollama/build.yaml index da33b8d53..58bd8e854 100644 --- a/llama_stack/templates/ollama/build.yaml +++ b/llama_stack/templates/ollama/build.yaml @@ -29,4 +29,5 @@ distribution_spec: - inline::code-interpreter - inline::rag-runtime - remote::model-context-protocol + - remote::wolfram-alpha image_type: conda diff --git a/llama_stack/templates/ollama/ollama.py b/llama_stack/templates/ollama/ollama.py index 2345bf3e5..16d8a259f 100644 --- a/llama_stack/templates/ollama/ollama.py +++ b/llama_stack/templates/ollama/ollama.py @@ -34,6 +34,7 @@ def get_distribution_template() -> DistributionTemplate: "inline::code-interpreter", "inline::rag-runtime", "remote::model-context-protocol", + "remote::wolfram-alpha", ], } name = "ollama" @@ -78,6 +79,10 @@ def get_distribution_template() -> DistributionTemplate: toolgroup_id="builtin::code_interpreter", provider_id="code-interpreter", ), + ToolGroupInput( + toolgroup_id="builtin::wolfram_alpha", + provider_id="wolfram-alpha", + ), ] return DistributionTemplate( diff --git a/llama_stack/templates/ollama/run-with-safety.yaml b/llama_stack/templates/ollama/run-with-safety.yaml index d5766dec1..c8d5a22a4 100644 --- a/llama_stack/templates/ollama/run-with-safety.yaml +++ b/llama_stack/templates/ollama/run-with-safety.yaml @@ -85,6 +85,9 @@ providers: - provider_id: model-context-protocol provider_type: remote::model-context-protocol config: {} + - provider_id: wolfram-alpha + provider_type: remote::wolfram-alpha + config: {} metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db @@ -119,5 +122,7 @@ tool_groups: provider_id: 
rag-runtime - toolgroup_id: builtin::code_interpreter provider_id: code-interpreter +- toolgroup_id: builtin::wolfram_alpha + provider_id: wolfram-alpha server: port: 8321 diff --git a/llama_stack/templates/ollama/run.yaml b/llama_stack/templates/ollama/run.yaml index a2428688e..fa21170d2 100644 --- a/llama_stack/templates/ollama/run.yaml +++ b/llama_stack/templates/ollama/run.yaml @@ -82,6 +82,9 @@ providers: - provider_id: model-context-protocol provider_type: remote::model-context-protocol config: {} + - provider_id: wolfram-alpha + provider_type: remote::wolfram-alpha + config: {} metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db @@ -108,5 +111,7 @@ tool_groups: provider_id: rag-runtime - toolgroup_id: builtin::code_interpreter provider_id: code-interpreter +- toolgroup_id: builtin::wolfram_alpha + provider_id: wolfram-alpha server: port: 8321 diff --git a/llama_stack/templates/remote-vllm/build.yaml b/llama_stack/templates/remote-vllm/build.yaml index ccb328c1c..b2bbf853a 100644 --- a/llama_stack/templates/remote-vllm/build.yaml +++ b/llama_stack/templates/remote-vllm/build.yaml @@ -30,4 +30,5 @@ distribution_spec: - inline::code-interpreter - inline::rag-runtime - remote::model-context-protocol + - remote::wolfram-alpha image_type: conda diff --git a/llama_stack/templates/remote-vllm/run-with-safety.yaml b/llama_stack/templates/remote-vllm/run-with-safety.yaml index dd43f21f6..45af8427a 100644 --- a/llama_stack/templates/remote-vllm/run-with-safety.yaml +++ b/llama_stack/templates/remote-vllm/run-with-safety.yaml @@ -96,6 +96,9 @@ providers: - provider_id: model-context-protocol provider_type: remote::model-context-protocol config: {} + - provider_id: wolfram-alpha + provider_type: remote::wolfram-alpha + config: {} metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db @@ -126,5 +129,7 @@ tool_groups: provider_id: rag-runtime - toolgroup_id: builtin::code_interpreter provider_id: code-interpreter +- toolgroup_id: builtin::wolfram_alpha + provider_id: wolfram-alpha server: port: 8321 diff --git a/llama_stack/templates/remote-vllm/run.yaml b/llama_stack/templates/remote-vllm/run.yaml index 24cd207c7..674085045 100644 --- a/llama_stack/templates/remote-vllm/run.yaml +++ b/llama_stack/templates/remote-vllm/run.yaml @@ -90,6 +90,9 @@ providers: - provider_id: model-context-protocol provider_type: remote::model-context-protocol config: {} + - provider_id: wolfram-alpha + provider_type: remote::wolfram-alpha + config: {} metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db @@ -115,5 +118,7 @@ tool_groups: provider_id: rag-runtime - toolgroup_id: builtin::code_interpreter provider_id: code-interpreter +- toolgroup_id: builtin::wolfram_alpha + provider_id: wolfram-alpha server: port: 8321 diff --git a/llama_stack/templates/remote-vllm/vllm.py b/llama_stack/templates/remote-vllm/vllm.py index 16bf1d0fa..9901fc83b 100644 --- a/llama_stack/templates/remote-vllm/vllm.py +++ b/llama_stack/templates/remote-vllm/vllm.py @@ -37,6 +37,7 @@ def get_distribution_template() -> DistributionTemplate: "inline::code-interpreter", "inline::rag-runtime", "remote::model-context-protocol", + "remote::wolfram-alpha", ], } name = "remote-vllm" @@ -87,6 +88,10 @@ def get_distribution_template() -> DistributionTemplate: toolgroup_id="builtin::code_interpreter", provider_id="code-interpreter", ), + ToolGroupInput( + toolgroup_id="builtin::wolfram_alpha", 
+ provider_id="wolfram-alpha", + ), ] return DistributionTemplate( diff --git a/llama_stack/templates/together/build.yaml b/llama_stack/templates/together/build.yaml index a8a6de28d..834a3ecaf 100644 --- a/llama_stack/templates/together/build.yaml +++ b/llama_stack/templates/together/build.yaml @@ -30,4 +30,5 @@ distribution_spec: - inline::code-interpreter - inline::rag-runtime - remote::model-context-protocol + - remote::wolfram-alpha image_type: conda diff --git a/llama_stack/templates/together/run-with-safety.yaml b/llama_stack/templates/together/run-with-safety.yaml index 26d879802..fd74f80c3 100644 --- a/llama_stack/templates/together/run-with-safety.yaml +++ b/llama_stack/templates/together/run-with-safety.yaml @@ -95,6 +95,9 @@ providers: - provider_id: model-context-protocol provider_type: remote::model-context-protocol config: {} + - provider_id: wolfram-alpha + provider_type: remote::wolfram-alpha + config: {} metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/registry.db @@ -226,5 +229,7 @@ tool_groups: provider_id: rag-runtime - toolgroup_id: builtin::code_interpreter provider_id: code-interpreter +- toolgroup_id: builtin::wolfram_alpha + provider_id: wolfram-alpha server: port: 8321 diff --git a/llama_stack/templates/together/run.yaml b/llama_stack/templates/together/run.yaml index 0969cfe56..9a717290a 100644 --- a/llama_stack/templates/together/run.yaml +++ b/llama_stack/templates/together/run.yaml @@ -89,6 +89,9 @@ providers: - provider_id: model-context-protocol provider_type: remote::model-context-protocol config: {} + - provider_id: wolfram-alpha + provider_type: remote::wolfram-alpha + config: {} metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/registry.db @@ -215,5 +218,7 @@ tool_groups: provider_id: rag-runtime - toolgroup_id: builtin::code_interpreter provider_id: code-interpreter +- toolgroup_id: builtin::wolfram_alpha + provider_id: wolfram-alpha server: port: 8321 diff --git a/llama_stack/templates/together/together.py b/llama_stack/templates/together/together.py index bf6f0cea4..fce03a1b2 100644 --- a/llama_stack/templates/together/together.py +++ b/llama_stack/templates/together/together.py @@ -38,6 +38,7 @@ def get_distribution_template() -> DistributionTemplate: "inline::code-interpreter", "inline::rag-runtime", "remote::model-context-protocol", + "remote::wolfram-alpha", ], } name = "together" @@ -73,6 +74,10 @@ def get_distribution_template() -> DistributionTemplate: toolgroup_id="builtin::code_interpreter", provider_id="code-interpreter", ), + ToolGroupInput( + toolgroup_id="builtin::wolfram_alpha", + provider_id="wolfram-alpha", + ), ] embedding_model = ModelInput( model_id="all-MiniLM-L6-v2", diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 8e0cbdf65..dada5449f 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -20,7 +20,7 @@ from llama_stack.distribution.datatypes import Provider, StackRunConfig from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.stack import replace_env_vars from llama_stack.distribution.utils.dynamic import instantiate_class_type -from llama_stack.providers.tests.env import get_env_or_fail +from llama_stack.env import get_env_or_fail from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig from .fixtures.recordable_mock import RecordableMock @@ -84,6 +84,11 @@ def pytest_addoption(parser): default=None, help="Specify the 
embedding model to use for testing", ) + parser.addoption( + "--judge-model", + default=None, + help="Specify the judge model to use for testing", + ) parser.addoption( "--embedding-dimension", type=int, @@ -109,6 +114,7 @@ def provider_data(): "TOGETHER_API_KEY": "together_api_key", "ANTHROPIC_API_KEY": "anthropic_api_key", "GROQ_API_KEY": "groq_api_key", + "WOLFRAM_ALPHA_API_KEY": "wolfram_alpha_api_key", } provider_data = {} for key, value in keymap.items(): @@ -260,7 +266,9 @@ def inference_provider_type(llama_stack_client): @pytest.fixture(scope="session") -def client_with_models(llama_stack_client, text_model_id, vision_model_id, embedding_model_id, embedding_dimension): +def client_with_models( + llama_stack_client, text_model_id, vision_model_id, embedding_model_id, embedding_dimension, judge_model_id +): client = llama_stack_client providers = [p for p in client.providers.list() if p.api == "inference"] @@ -274,6 +282,8 @@ def client_with_models(llama_stack_client, text_model_id, vision_model_id, embed client.models.register(model_id=text_model_id, provider_id=inference_providers[0]) if vision_model_id and vision_model_id not in model_ids: client.models.register(model_id=vision_model_id, provider_id=inference_providers[0]) + if judge_model_id and judge_model_id not in model_ids: + client.models.register(model_id=judge_model_id, provider_id=inference_providers[0]) if embedding_model_id and embedding_dimension and embedding_model_id not in model_ids: # try to find a provider that supports embeddings, if sentence-transformers is not available @@ -328,6 +338,14 @@ def pytest_generate_tests(metafunc): if val is not None: id_parts.append(f"emb={get_short_id(val)}") + if "judge_model_id" in metafunc.fixturenames: + params.append("judge_model_id") + val = metafunc.config.getoption("--judge-model") + print(f"judge_model_id: {val}") + values.append(val) + if val is not None: + id_parts.append(f"judge={get_short_id(val)}") + if "embedding_dimension" in metafunc.fixturenames: params.append("embedding_dimension") val = metafunc.config.getoption("--embedding-dimension") diff --git a/llama_stack/providers/tests/datasetio/__init__.py b/tests/integration/datasetio/__init__.py similarity index 100% rename from llama_stack/providers/tests/datasetio/__init__.py rename to tests/integration/datasetio/__init__.py diff --git a/llama_stack/providers/tests/datasetio/test_dataset.csv b/tests/integration/datasetio/test_dataset.csv similarity index 100% rename from llama_stack/providers/tests/datasetio/test_dataset.csv rename to tests/integration/datasetio/test_dataset.csv diff --git a/tests/integration/datasetio/test_datasetio.py b/tests/integration/datasetio/test_datasetio.py new file mode 100644 index 000000000..899cb8c43 --- /dev/null +++ b/tests/integration/datasetio/test_datasetio.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import base64 +import mimetypes +import os +from pathlib import Path + +import pytest + +# How to run this test: +# +# pytest llama_stack/providers/tests/datasetio/test_datasetio.py +# -m "meta_reference" +# -v -s --tb=short --disable-warnings + + +def data_url_from_file(file_path: str) -> str: + if not os.path.exists(file_path): + raise FileNotFoundError(f"File not found: {file_path}") + + with open(file_path, "rb") as file: + file_content = file.read() + + base64_content = base64.b64encode(file_content).decode("utf-8") + mime_type, _ = mimetypes.guess_type(file_path) + + data_url = f"data:{mime_type};base64,{base64_content}" + + return data_url + + +def register_dataset(llama_stack_client, for_generation=False, for_rag=False, dataset_id="test_dataset"): + if for_rag: + test_file = Path(os.path.abspath(__file__)).parent / "test_rag_dataset.csv" + else: + test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv" + test_url = data_url_from_file(str(test_file)) + + if for_generation: + dataset_schema = { + "expected_answer": {"type": "string"}, + "input_query": {"type": "string"}, + "chat_completion_input": {"type": "chat_completion_input"}, + } + elif for_rag: + dataset_schema = { + "expected_answer": {"type": "string"}, + "input_query": {"type": "string"}, + "generated_answer": {"type": "string"}, + "context": {"type": "string"}, + } + else: + dataset_schema = { + "expected_answer": {"type": "string"}, + "input_query": {"type": "string"}, + "generated_answer": {"type": "string"}, + } + + llama_stack_client.datasets.register( + dataset_id=dataset_id, + dataset_schema=dataset_schema, + url=dict(uri=test_url), + provider_id="localfs", + ) + + +def test_datasets_list(llama_stack_client): + # NOTE: this needs you to ensure that you are starting from a clean state + # but so far we don't have an unregister API unfortunately, so be careful + + response = llama_stack_client.datasets.list() + assert isinstance(response, list) + assert len(response) == 0 + + +def test_register_dataset(llama_stack_client): + register_dataset(llama_stack_client) + response = llama_stack_client.datasets.list() + assert isinstance(response, list) + assert len(response) == 1 + assert response[0].identifier == "test_dataset" + + with pytest.raises(ValueError): + # unregister a dataset that does not exist + llama_stack_client.datasets.unregister("test_dataset2") + + llama_stack_client.datasets.unregister("test_dataset") + response = llama_stack_client.datasets.list() + assert isinstance(response, list) + assert len(response) == 0 + + with pytest.raises(ValueError): + llama_stack_client.datasets.unregister("test_dataset") + + +def test_get_rows_paginated(llama_stack_client): + register_dataset(llama_stack_client) + response = llama_stack_client.datasetio.get_rows_paginated( + dataset_id="test_dataset", + rows_in_page=3, + ) + assert isinstance(response.rows, list) + assert len(response.rows) == 3 + assert response.next_page_token == "3" + + # iterate over all rows + response = llama_stack_client.datasetio.get_rows_paginated( + dataset_id="test_dataset", + rows_in_page=2, + page_token=response.next_page_token, + ) + assert isinstance(response.rows, list) + assert len(response.rows) == 2 + assert response.next_page_token == "5" diff --git a/llama_stack/providers/tests/datasetio/test_rag_dataset.csv b/tests/integration/datasetio/test_rag_dataset.csv similarity index 100% rename from llama_stack/providers/tests/datasetio/test_rag_dataset.csv rename to tests/integration/datasetio/test_rag_dataset.csv diff --git 
diff --git a/llama_stack/providers/tests/eval/__init__.py b/tests/integration/eval/__init__.py
similarity index 100%
rename from llama_stack/providers/tests/eval/__init__.py
rename to tests/integration/eval/__init__.py
diff --git a/llama_stack/providers/tests/eval/constants.py b/tests/integration/eval/constants.py
similarity index 100%
rename from llama_stack/providers/tests/eval/constants.py
rename to tests/integration/eval/constants.py
diff --git a/llama_stack/providers/tests/eval/test_eval.py b/tests/integration/eval/test_eval.py
similarity index 95%
rename from llama_stack/providers/tests/eval/test_eval.py
rename to tests/integration/eval/test_eval.py
index 4470ffe4c..a7d59a2de 100644
--- a/llama_stack/providers/tests/eval/test_eval.py
+++ b/tests/integration/eval/test_eval.py
@@ -10,15 +10,13 @@ import pytest
 from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType
 from llama_stack.apis.eval.eval import (
-    AppBenchmarkConfig,
-    BenchmarkBenchmarkConfig,
     ModelCandidate,
 )
 from llama_stack.apis.inference import SamplingParams
 from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams
 from llama_stack.distribution.datatypes import Api
-from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
+from ..datasetio.test_datasetio import register_dataset
 from .constants import JUDGE_PROMPT
 
 # How to run this test:
@@ -28,6 +26,7 @@ from .constants import JUDGE_PROMPT
 # -v -s --tb=short --disable-warnings
 
 
+@pytest.mark.skip(reason="FIXME FIXME @yanxi0830 this needs to be migrated to use the API")
 class Testeval:
     @pytest.mark.asyncio
     async def test_benchmarks_list(self, eval_stack):
@@ -68,7 +67,7 @@ class Testeval:
             benchmark_id=benchmark_id,
             input_rows=rows.rows,
             scoring_functions=scoring_functions,
-            benchmark_config=AppBenchmarkConfig(
+            benchmark_config=dict(
                 eval_candidate=ModelCandidate(
                     model=inference_model,
                     sampling_params=SamplingParams(),
@@ -111,7 +110,7 @@ class Testeval:
         )
         response = await eval_impl.run_eval(
             benchmark_id=benchmark_id,
-            benchmark_config=AppBenchmarkConfig(
+            benchmark_config=dict(
                 eval_candidate=ModelCandidate(
                     model=inference_model,
                     sampling_params=SamplingParams(),
@@ -169,7 +168,7 @@ class Testeval:
         benchmark_id = "meta-reference-mmlu"
         response = await eval_impl.run_eval(
             benchmark_id=benchmark_id,
-            benchmark_config=BenchmarkBenchmarkConfig(
+            benchmark_config=dict(
                 eval_candidate=ModelCandidate(
                     model=inference_model,
                     sampling_params=SamplingParams(),
diff --git a/llama_stack/providers/tests/post_training/__init__.py b/tests/integration/post_training/__init__.py
similarity index 100%
rename from llama_stack/providers/tests/post_training/__init__.py
rename to tests/integration/post_training/__init__.py
diff --git a/llama_stack/providers/tests/post_training/test_post_training.py b/tests/integration/post_training/test_post_training.py
similarity index 97%
rename from llama_stack/providers/tests/post_training/test_post_training.py
rename to tests/integration/post_training/test_post_training.py
index aefef5332..3e22bc5a7 100644
--- a/llama_stack/providers/tests/post_training/test_post_training.py
+++ b/tests/integration/post_training/test_post_training.py
@@ -26,6 +26,7 @@ from llama_stack.apis.post_training import (
 # -v -s --tb=short --disable-warnings
 
 
+@pytest.mark.skip(reason="FIXME FIXME @yanxi0830 this needs to be migrated to use the API")
 class TestPostTraining:
     @pytest.mark.asyncio
     async def test_supervised_fine_tune(self, post_training_stack):
diff --git a/tests/integration/report.py b/tests/integration/report.py
index 762a7afcb..fd6c4f7a8 100644
--- a/tests/integration/report.py
+++ b/tests/integration/report.py
@@ -16,6 +16,7 @@ import pytest
 from pytest import CollectReport
 from termcolor import cprint
 
+from llama_stack.env import get_env_or_fail
 from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.models.llama.sku_list import (
     all_registered_models,
@@ -26,7 +27,6 @@ from llama_stack.models.llama.sku_list import (
     safety_models,
 )
 from llama_stack.providers.datatypes import Api
-from llama_stack.providers.tests.env import get_env_or_fail
 
 from .metadata import API_MAPS
 
diff --git a/llama_stack/providers/tests/scoring/__init__.py b/tests/integration/scoring/__init__.py
similarity index 100%
rename from llama_stack/providers/tests/scoring/__init__.py
rename to tests/integration/scoring/__init__.py
diff --git a/tests/integration/scoring/test_scoring.py b/tests/integration/scoring/test_scoring.py
new file mode 100644
index 000000000..b695c2ef7
--- /dev/null
+++ b/tests/integration/scoring/test_scoring.py
@@ -0,0 +1,160 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+import pytest
+
+from ..datasetio.test_datasetio import register_dataset
+
+
+@pytest.fixture
+def sample_judge_prompt_template():
+    return "Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9."
+
+
+def test_scoring_functions_list(llama_stack_client):
+    # NOTE: this needs you to ensure that you are starting from a clean state
+    # but so far we don't have an unregister API unfortunately, so be careful
+    response = llama_stack_client.scoring_functions.list()
+    assert isinstance(response, list)
+    assert len(response) > 0
+
+
+def test_scoring_score(llama_stack_client):
+    register_dataset(llama_stack_client, for_rag=True)
+    response = llama_stack_client.datasets.list()
+    assert len(response) == 1
+
+    # scoring individual rows
+    rows = llama_stack_client.datasetio.get_rows_paginated(
+        dataset_id="test_dataset",
+        rows_in_page=3,
+    )
+    assert len(rows.rows) == 3
+
+    scoring_fns_list = llama_stack_client.scoring_functions.list()
+    scoring_functions = {
+        scoring_fns_list[0].identifier: None,
+    }
+
+    response = llama_stack_client.scoring.score(
+        input_rows=rows.rows,
+        scoring_functions=scoring_functions,
+    )
+    assert len(response.results) == len(scoring_functions)
+    for x in scoring_functions:
+        assert x in response.results
+        assert len(response.results[x].score_rows) == len(rows.rows)
+
+    # score batch
+    response = llama_stack_client.scoring.score_batch(
+        dataset_id="test_dataset",
+        scoring_functions=scoring_functions,
+        save_results_dataset=False,
+    )
+    assert len(response.results) == len(scoring_functions)
+    for x in scoring_functions:
+        assert x in response.results
+        assert len(response.results[x].score_rows) == 5
+
+
+def test_scoring_score_with_params_llm_as_judge(llama_stack_client, sample_judge_prompt_template, judge_model_id):
+    register_dataset(llama_stack_client, for_rag=True)
+    response = llama_stack_client.datasets.list()
+    assert len(response) == 1
+
+    # scoring individual rows
+    rows = llama_stack_client.datasetio.get_rows_paginated(
+        dataset_id="test_dataset",
+        rows_in_page=3,
+    )
+    assert len(rows.rows) == 3
+
+    scoring_functions = {
+        "llm-as-judge::base": dict(
+            type="llm_as_judge",
+            judge_model=judge_model_id,
+            prompt_template=sample_judge_prompt_template,
+            judge_score_regexes=[r"Score: (\d+)"],
+            aggregation_functions=[
+                "categorical_count",
+            ],
+        )
+    }
+
+    response = llama_stack_client.scoring.score(
+        input_rows=rows.rows,
+        scoring_functions=scoring_functions,
+    )
+    assert len(response.results) == len(scoring_functions)
+    for x in scoring_functions:
+        assert x in response.results
+        assert len(response.results[x].score_rows) == len(rows.rows)
+
+    # score batch
+    response = llama_stack_client.scoring.score_batch(
+        dataset_id="test_dataset",
+        scoring_functions=scoring_functions,
+        save_results_dataset=False,
+    )
+    assert len(response.results) == len(scoring_functions)
+    for x in scoring_functions:
+        assert x in response.results
+        assert len(response.results[x].score_rows) == 5
+
+
+@pytest.mark.skip(reason="Skipping because this seems to be really slow")
+def test_scoring_score_with_aggregation_functions(llama_stack_client, sample_judge_prompt_template, judge_model_id):
+    register_dataset(llama_stack_client, for_rag=True)
+    rows = llama_stack_client.datasetio.get_rows_paginated(
+        dataset_id="test_dataset",
+        rows_in_page=3,
+    )
+    assert len(rows.rows) == 3
+
+    scoring_fns_list = llama_stack_client.scoring_functions.list()
+    scoring_functions = {}
+    aggr_fns = [
+        "accuracy",
+        "median",
+        "categorical_count",
+        "average",
+    ]
+    for x in scoring_fns_list:
+        if x.provider_id == "llm-as-judge":
+            aggr_fns = ["categorical_count"]
+            scoring_functions[x.identifier] = dict(
+                type="llm_as_judge",
+                judge_model=judge_model_id,
+                prompt_template=sample_judge_prompt_template,
+                judge_score_regexes=[r"Score: (\d+)"],
+                aggregation_functions=aggr_fns,
+            )
+        elif x.provider_id == "basic" or x.provider_id == "braintrust":
+            if "regex_parser" in x.identifier:
+                scoring_functions[x.identifier] = dict(
+                    type="regex_parser",
+                    parsing_regexes=[r"Score: (\d+)"],
+                    aggregation_functions=aggr_fns,
+                )
+            else:
+                scoring_functions[x.identifier] = dict(
+                    type="basic",
+                    aggregation_functions=aggr_fns,
+                )
+        else:
+            scoring_functions[x.identifier] = None
+
+    response = llama_stack_client.scoring.score(
+        input_rows=rows.rows,
+        scoring_functions=scoring_functions,
+    )
+
+    assert len(response.results) == len(scoring_functions)
+    for x in scoring_functions:
+        assert x in response.results
+        assert len(response.results[x].score_rows) == len(rows.rows)
+        assert len(response.results[x].aggregated_results) == len(aggr_fns)
diff --git a/tests/integration/tool_runtime/test_builtin_tools.py b/tests/integration/tool_runtime/test_builtin_tools.py
new file mode 100644
index 000000000..9edf3afa0
--- /dev/null
+++ b/tests/integration/tool_runtime/test_builtin_tools.py
@@ -0,0 +1,66 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+import os
+
+import pytest
+
+
+@pytest.fixture
+def sample_search_query():
+    return "What are the latest developments in quantum computing?"
+
+
+@pytest.fixture
+def sample_wolfram_alpha_query():
+    return "What is the square root of 16?"
+
+
+def test_web_search_tool(llama_stack_client, sample_search_query):
+    """Test the web search tool functionality."""
+    if "TAVILY_SEARCH_API_KEY" not in os.environ:
+        pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
+
+    response = llama_stack_client.tool_runtime.invoke_tool(
+        tool_name="web_search", kwargs={"query": sample_search_query}
+    )
+
+    # Verify the response
+    assert response.content is not None
+    assert len(response.content) > 0
+    assert isinstance(response.content, str)
+
+    content = json.loads(response.content)
+    assert "query" in content
+    assert "top_k" in content
+    assert len(content["top_k"]) > 0
+
+    first = content["top_k"][0]
+    assert "title" in first
+    assert "url" in first
+
+
+def test_wolfram_alpha_tool(llama_stack_client, sample_wolfram_alpha_query):
+    """Test the wolfram alpha tool functionality."""
+    if "WOLFRAM_ALPHA_API_KEY" not in os.environ:
+        pytest.skip("WOLFRAM_ALPHA_API_KEY not set, skipping test")
+
+    response = llama_stack_client.tool_runtime.invoke_tool(
+        tool_name="wolfram_alpha", kwargs={"query": sample_wolfram_alpha_query}
+    )
+
+    print(response.content)
+    assert response.content is not None
+    assert len(response.content) > 0
+    assert isinstance(response.content, str)
+
+    content = json.loads(response.content)
+    result = content["queryresult"]
+    assert "success" in result
+    assert result["success"]
+    assert "pods" in result
+    assert len(result["pods"]) > 0
diff --git a/tests/integration/tool_runtime/test_rag_tool.py b/tests/integration/tool_runtime/test_rag_tool.py
index e330a10f5..c49f507a8 100644
--- a/tests/integration/tool_runtime/test_rag_tool.py
+++ b/tests/integration/tool_runtime/test_rag_tool.py
@@ -4,29 +4,23 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import random
-
 import pytest
 from llama_stack_client.types import Document
 
 
 @pytest.fixture(scope="function")
-def empty_vector_db_registry(llama_stack_client):
-    vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
-    for vector_db_id in vector_dbs:
-        llama_stack_client.vector_dbs.unregister(vector_db_id=vector_db_id)
+def client_with_empty_registry(client_with_models):
+    def clear_registry():
+        vector_dbs = [vector_db.identifier for vector_db in client_with_models.vector_dbs.list()]
+        for vector_db_id in vector_dbs:
+            client_with_models.vector_dbs.unregister(vector_db_id=vector_db_id)
 
+    clear_registry()
+    yield client_with_models
 
-@pytest.fixture(scope="function")
-def single_entry_vector_db_registry(llama_stack_client, empty_vector_db_registry):
-    vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}"
-    llama_stack_client.vector_dbs.register(
-        vector_db_id=vector_db_id,
-        embedding_model="all-MiniLM-L6-v2",
-        embedding_dimension=384,
-    )
-    vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
-    return vector_dbs
+    # you must clean up after the last test if you were running tests against
+    # a stateful server instance
+    clear_registry()
 
 
 @pytest.fixture(scope="session")
@@ -63,9 +57,15 @@ def assert_valid_response(response):
         assert isinstance(chunk.content, str)
 
 
-def test_vector_db_insert_inline_and_query(llama_stack_client, single_entry_vector_db_registry, sample_documents):
-    vector_db_id = single_entry_vector_db_registry[0]
-    llama_stack_client.tool_runtime.rag_tool.insert(
+def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_documents, embedding_model_id):
+    vector_db_id = "test_vector_db"
+    client_with_empty_registry.vector_dbs.register(
+        vector_db_id=vector_db_id,
+        embedding_model=embedding_model_id,
+        embedding_dimension=384,
+    )
+
+    client_with_empty_registry.tool_runtime.rag_tool.insert(
         documents=sample_documents,
         chunk_size_in_tokens=512,
         vector_db_id=vector_db_id,
@@ -73,7 +73,7 @@ def test_vector_db_insert_inline_and_query(llama_stack_client, single_entry_vector_db_registry, sample_documents):
 
     # Query with a direct match
     query1 = "programming language"
-    response1 = llama_stack_client.vector_io.query(
+    response1 = client_with_empty_registry.vector_io.query(
         vector_db_id=vector_db_id,
         query=query1,
     )
@@ -82,7 +82,7 @@ def test_vector_db_insert_inline_and_query(llama_stack_client, single_entry_vector_db_registry, sample_documents):
 
     # Query with semantic similarity
    query2 = "AI and brain-inspired computing"
-    response2 = llama_stack_client.vector_io.query(
+    response2 = client_with_empty_registry.vector_io.query(
         vector_db_id=vector_db_id,
         query=query2,
     )
@@ -91,7 +91,7 @@ def test_vector_db_insert_inline_and_query(llama_stack_client, single_entry_vector_db_registry, sample_documents):
 
     # Query with limit on number of results (max_chunks=2)
     query3 = "computer"
-    response3 = llama_stack_client.vector_io.query(
+    response3 = client_with_empty_registry.vector_io.query(
         vector_db_id=vector_db_id,
         query=query3,
         params={"max_chunks": 2},
@@ -101,7 +101,7 @@ def test_vector_db_insert_inline_and_query(llama_stack_client, single_entry_vector_db_registry, sample_documents):
 
     # Query with threshold on similarity score
     query4 = "computer"
-    response4 = llama_stack_client.vector_io.query(
+    response4 = client_with_empty_registry.vector_io.query(
         vector_db_id=vector_db_id,
         query=query4,
         params={"score_threshold": 0.01},
@@ -110,20 +110,20 @@ def test_vector_db_insert_inline_and_query(llama_stack_client, single_entry_vector_db_registry, sample_documents):
     assert all(score >= 0.01 for score in response4.scores)
 
 
-def test_vector_db_insert_from_url_and_query(llama_stack_client, empty_vector_db_registry):
-    providers = [p for p in llama_stack_client.providers.list() if p.api == "vector_io"]
+def test_vector_db_insert_from_url_and_query(client_with_empty_registry, sample_documents, embedding_model_id):
+    providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
     assert len(providers) > 0
 
     vector_db_id = "test_vector_db"
 
-    llama_stack_client.vector_dbs.register(
+    client_with_empty_registry.vector_dbs.register(
         vector_db_id=vector_db_id,
-        embedding_model="all-MiniLM-L6-v2",
+        embedding_model=embedding_model_id,
         embedding_dimension=384,
     )
 
     # list to check memory bank is successfully registered
-    available_vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
+    available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
     assert vector_db_id in available_vector_dbs
 
     # URLs of documents to insert
@@ -144,14 +144,14 @@ def test_vector_db_insert_from_url_and_query(llama_stack_client, empty_vector_db_registry):
         for i, url in enumerate(urls)
     ]
 
-    llama_stack_client.tool_runtime.rag_tool.insert(
+    client_with_empty_registry.tool_runtime.rag_tool.insert(
         documents=documents,
         vector_db_id=vector_db_id,
         chunk_size_in_tokens=512,
     )
 
     # Query for the name of method
-    response1 = llama_stack_client.vector_io.query(
+    response1 = client_with_empty_registry.vector_io.query(
         vector_db_id=vector_db_id,
         query="What's the name of the fine-tunning method used?",
     )
@@ -159,7 +159,7 @@ def test_vector_db_insert_from_url_and_query(llama_stack_client, empty_vector_db_registry):
     assert any("lora" in chunk.content.lower() for chunk in response1.chunks)
 
     # Query for the name of model
-    response2 = llama_stack_client.vector_io.query(
+    response2 = client_with_empty_registry.vector_io.query(
         vector_db_id=vector_db_id,
         query="Which Llama model is mentioned?",
     )
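Note on the fixtures referenced above: the migrated integration tests depend on shared fixtures such as llama_stack_client, client_with_models, embedding_model_id, and judge_model_id that live in the integration-test conftest and are not part of this patch. The snippet below is a minimal, hypothetical sketch of how such fixtures could be provided against an already-running stack server; the real conftest in tests/integration may construct the client differently (for example as a library client), and the LLAMA_STACK_BASE_URL and JUDGE_MODEL_ID variable names here are purely illustrative.

# Hypothetical conftest.py sketch -- not part of this patch; it only illustrates
# one way the fixtures used by the tests above could be wired up.
import os

import pytest
from llama_stack_client import LlamaStackClient


@pytest.fixture(scope="session")
def llama_stack_client():
    # Assumes a Llama Stack server is already running and reachable.
    # LLAMA_STACK_BASE_URL is an illustrative variable name, not necessarily
    # the one the real suite reads; 8321 is the stack's default port.
    base_url = os.environ.get("LLAMA_STACK_BASE_URL", "http://localhost:8321")
    return LlamaStackClient(base_url=base_url)


@pytest.fixture(scope="session")
def judge_model_id():
    # Illustrative default only; the real suite selects the judge model via its
    # own configuration or CLI options.
    return os.environ.get("JUDGE_MODEL_ID", "meta-llama/Llama-3.3-70B-Instruct")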