refactor(test): move tools, evals, datasetio, scoring and post training tests (#1401)

All of the tests from `llama_stack/providers/tests/` are now moved to
`tests/integration`.

I converted the `tools`, `scoring` and `datasetio` tests to use the API. However, `eval` and `post_training` proved to be a bit more challenging to convert, so I am leaving those for now. I think `post_training` should be relatively straightforward to do as well.
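
Concretely, the conversion swaps the per-test provider fixtures and in-process stack for the shared `llama_stack_client` fixture used elsewhere in `tests/integration`. A rough before/after sketch, based on the datasetio tests in this change (the client fixture comes from the shared conftest):

```python
# Before: provider fixtures build an in-process stack and tests go through impls
# @pytest.mark.asyncio
# async def test_datasets_list(self, datasetio_stack):
#     _, datasets_impl = datasetio_stack
#     response = await datasets_impl.list_datasets()
#     assert len(response) == 0

# After: plain pytest functions driven by the llama_stack_client fixture
def test_datasets_list(llama_stack_client):
    response = llama_stack_client.datasets.list()
    assert isinstance(response, list)
    assert len(response) == 0
```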

As part of this, I noticed that the `wolfram_alpha` tool wasn't added to some of our commonly used distros, so I added it. I am going to remove a lot of code duplication from the distros next, so while this looks like a one-off right now, it will go away and the tool will be available uniformly across all distros.
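
For reference, the per-distro change boils down to adding the `remote::wolfram-alpha` tool runtime provider and registering a `builtin::wolfram_alpha` tool group. A minimal sketch of just those pieces (the real templates wire them into a full `DistributionTemplate`; the variable names here are only illustrative):

```python
from llama_stack.apis.tools import ToolGroupInput

# Tool runtime providers for the distro, with wolfram-alpha added
tool_runtime_providers = [
    "remote::brave-search",
    "remote::tavily-search",
    "remote::wolfram-alpha",
    "inline::code-interpreter",
    "inline::rag-runtime",
    "remote::model-context-protocol",
]

# Tool group that exposes the provider as builtin::wolfram_alpha
wolfram_alpha_tool_group = ToolGroupInput(
    toolgroup_id="builtin::wolfram_alpha",
    provider_id="wolfram-alpha",
)
```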
Ashwin Bharambe 2025-03-04 14:53:47 -08:00 committed by GitHub
parent dd0db8038b
commit abfbaf3c1b
51 changed files with 471 additions and 1245 deletions


@@ -22,7 +22,7 @@ The `llamastack/distribution-fireworks` distribution consists of the following p
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `remote::wolfram-alpha`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |


@@ -22,7 +22,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
| vector_io | `inline::sqlite-vec`, `remote::chromadb`, `remote::pgvector` |


@@ -21,7 +21,7 @@ The `llamastack/distribution-remote-vllm` distribution consists of the following
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |


@@ -22,7 +22,7 @@ The `llamastack/distribution-together` distribution consists of the following pr
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |


@@ -366,7 +366,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
-"No provider specified and multiple providers available. Please specify a provider_id."
+f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
)
if metadata is None:
metadata = {}


@ -1,29 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from .fixtures import DATASETIO_FIXTURES
def pytest_configure(config):
for fixture_name in DATASETIO_FIXTURES:
config.addinivalue_line(
"markers",
f"{fixture_name}: marks tests as {fixture_name} specific",
)
def pytest_generate_tests(metafunc):
if "datasetio_stack" in metafunc.fixturenames:
metafunc.parametrize(
"datasetio_stack",
[
pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
for fixture_name in DATASETIO_FIXTURES
],
indirect=True,
)


@ -1,61 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
import pytest_asyncio
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture, remote_stack_fixture
@pytest.fixture(scope="session")
def datasetio_remote() -> ProviderFixture:
return remote_stack_fixture()
@pytest.fixture(scope="session")
def datasetio_localfs() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="localfs",
provider_type="inline::localfs",
config={},
)
],
)
@pytest.fixture(scope="session")
def datasetio_huggingface() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="huggingface",
provider_type="remote::huggingface",
config={},
)
],
)
DATASETIO_FIXTURES = ["localfs", "remote", "huggingface"]
@pytest_asyncio.fixture(scope="session")
async def datasetio_stack(request):
fixture_name = request.param
fixture = request.getfixturevalue(f"datasetio_{fixture_name}")
test_stack = await construct_stack_for_test(
[Api.datasetio],
{"datasetio": fixture.providers},
fixture.provider_data,
)
return test_stack.impls[Api.datasetio], test_stack.impls[Api.datasets]


@ -1,134 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import mimetypes
import os
from pathlib import Path
import pytest
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType
from llama_stack.apis.datasets import Datasets
# How to run this test:
#
# pytest llama_stack/providers/tests/datasetio/test_datasetio.py
# -m "meta_reference"
# -v -s --tb=short --disable-warnings
def data_url_from_file(file_path: str) -> str:
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
with open(file_path, "rb") as file:
file_content = file.read()
base64_content = base64.b64encode(file_content).decode("utf-8")
mime_type, _ = mimetypes.guess_type(file_path)
data_url = f"data:{mime_type};base64,{base64_content}"
return data_url
async def register_dataset(
datasets_impl: Datasets,
for_generation=False,
for_rag=False,
dataset_id="test_dataset",
):
if for_rag:
test_file = Path(os.path.abspath(__file__)).parent / "test_rag_dataset.csv"
else:
test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
test_url = data_url_from_file(str(test_file))
if for_generation:
dataset_schema = {
"expected_answer": StringType(),
"input_query": StringType(),
"chat_completion_input": ChatCompletionInputType(),
}
elif for_rag:
dataset_schema = {
"expected_answer": StringType(),
"input_query": StringType(),
"generated_answer": StringType(),
"context": StringType(),
}
else:
dataset_schema = {
"expected_answer": StringType(),
"input_query": StringType(),
"generated_answer": StringType(),
}
await datasets_impl.register_dataset(
dataset_id=dataset_id,
dataset_schema=dataset_schema,
url=URL(uri=test_url),
)
class TestDatasetIO:
@pytest.mark.asyncio
async def test_datasets_list(self, datasetio_stack):
# NOTE: this needs you to ensure that you are starting from a clean state
# but so far we don't have an unregister API unfortunately, so be careful
_, datasets_impl = datasetio_stack
response = await datasets_impl.list_datasets()
assert isinstance(response, list)
assert len(response) == 0
@pytest.mark.asyncio
async def test_register_dataset(self, datasetio_stack):
_, datasets_impl = datasetio_stack
await register_dataset(datasets_impl)
response = await datasets_impl.list_datasets()
assert isinstance(response, list)
assert len(response) == 1
assert response[0].identifier == "test_dataset"
with pytest.raises(ValueError):
# unregister a dataset that does not exist
await datasets_impl.unregister_dataset("test_dataset2")
await datasets_impl.unregister_dataset("test_dataset")
response = await datasets_impl.list_datasets()
assert isinstance(response, list)
assert len(response) == 0
with pytest.raises(ValueError):
await datasets_impl.unregister_dataset("test_dataset")
@pytest.mark.asyncio
async def test_get_rows_paginated(self, datasetio_stack):
datasetio_impl, datasets_impl = datasetio_stack
await register_dataset(datasets_impl)
response = await datasetio_impl.get_rows_paginated(
dataset_id="test_dataset",
rows_in_page=3,
)
assert isinstance(response.rows, list)
assert len(response.rows) == 3
assert response.next_page_token == "3"
provider = datasetio_impl.routing_table.get_provider_impl("test_dataset")
if provider.__provider_spec__.provider_type == "remote":
pytest.skip("remote provider doesn't support get_rows_paginated")
# iterate over all rows
response = await datasetio_impl.get_rows_paginated(
dataset_id="test_dataset",
rows_in_page=2,
page_token=response.next_page_token,
)
assert isinstance(response.rows, list)
assert len(response.rows) == 2
assert response.next_page_token == "5"


@ -1,92 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from ..agents.fixtures import AGENTS_FIXTURES
from ..conftest import get_provider_fixture_overrides
from ..datasetio.fixtures import DATASETIO_FIXTURES
from ..inference.fixtures import INFERENCE_FIXTURES
from ..safety.fixtures import SAFETY_FIXTURES
from ..scoring.fixtures import SCORING_FIXTURES
from ..tools.fixtures import TOOL_RUNTIME_FIXTURES
from ..vector_io.fixtures import VECTOR_IO_FIXTURES
from .fixtures import EVAL_FIXTURES
DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
{
"eval": "meta_reference",
"scoring": "basic",
"datasetio": "localfs",
"inference": "fireworks",
"agents": "meta_reference",
"safety": "llama_guard",
"vector_io": "faiss",
"tool_runtime": "memory_and_search",
},
id="meta_reference_eval_fireworks_inference",
marks=pytest.mark.meta_reference_eval_fireworks_inference,
),
pytest.param(
{
"eval": "meta_reference",
"scoring": "basic",
"datasetio": "localfs",
"inference": "together",
"agents": "meta_reference",
"safety": "llama_guard",
"vector_io": "faiss",
"tool_runtime": "memory_and_search",
},
id="meta_reference_eval_together_inference",
marks=pytest.mark.meta_reference_eval_together_inference,
),
pytest.param(
{
"eval": "meta_reference",
"scoring": "basic",
"datasetio": "huggingface",
"inference": "together",
"agents": "meta_reference",
"safety": "llama_guard",
"vector_io": "faiss",
"tool_runtime": "memory_and_search",
},
id="meta_reference_eval_together_inference_huggingface_datasetio",
marks=pytest.mark.meta_reference_eval_together_inference_huggingface_datasetio,
),
]
def pytest_configure(config):
for fixture_name in [
"meta_reference_eval_fireworks_inference",
"meta_reference_eval_together_inference",
"meta_reference_eval_together_inference_huggingface_datasetio",
]:
config.addinivalue_line(
"markers",
f"{fixture_name}: marks tests as {fixture_name} specific",
)
def pytest_generate_tests(metafunc):
if "eval_stack" in metafunc.fixturenames:
available_fixtures = {
"eval": EVAL_FIXTURES,
"scoring": SCORING_FIXTURES,
"datasetio": DATASETIO_FIXTURES,
"inference": INFERENCE_FIXTURES,
"agents": AGENTS_FIXTURES,
"safety": SAFETY_FIXTURES,
"vector_io": VECTOR_IO_FIXTURES,
"tool_runtime": TOOL_RUNTIME_FIXTURES,
}
combinations = (
get_provider_fixture_overrides(metafunc.config, available_fixtures) or DEFAULT_PROVIDER_COMBINATIONS
)
metafunc.parametrize("eval_stack", combinations, indirect=True)


@ -1,87 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
import pytest_asyncio
from llama_stack.distribution.datatypes import Api, ModelInput, Provider
from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture, remote_stack_fixture
@pytest.fixture(scope="session")
def eval_remote() -> ProviderFixture:
return remote_stack_fixture()
@pytest.fixture(scope="session")
def eval_meta_reference() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="meta-reference",
provider_type="inline::meta-reference",
config={},
)
],
)
EVAL_FIXTURES = ["meta_reference", "remote"]
@pytest_asyncio.fixture(scope="session")
async def eval_stack(
request,
inference_model,
judge_model,
tool_group_input_memory,
tool_group_input_tavily_search,
):
fixture_dict = request.param
providers = {}
provider_data = {}
for key in [
"datasetio",
"eval",
"scoring",
"inference",
"agents",
"safety",
"vector_io",
"tool_runtime",
]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers
if fixture.provider_data:
provider_data.update(fixture.provider_data)
test_stack = await construct_stack_for_test(
[
Api.eval,
Api.datasetio,
Api.inference,
Api.scoring,
Api.agents,
Api.safety,
Api.vector_io,
Api.tool_runtime,
],
providers,
provider_data,
models=[
ModelInput(model_id=model)
for model in [
inference_model,
judge_model,
]
],
tool_groups=[tool_group_input_memory, tool_group_input_tavily_search],
)
return test_stack.impls


@ -1,42 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from ..conftest import get_provider_fixture_overrides
from ..datasetio.fixtures import DATASETIO_FIXTURES
from .fixtures import POST_TRAINING_FIXTURES
DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
{
"post_training": "torchtune",
"datasetio": "huggingface",
},
id="torchtune_post_training_huggingface_datasetio",
marks=pytest.mark.torchtune_post_training_huggingface_datasetio,
),
]
def pytest_configure(config):
combined_fixtures = "torchtune_post_training_huggingface_datasetio"
config.addinivalue_line(
"markers",
f"{combined_fixtures}: marks tests as {combined_fixtures} specific",
)
def pytest_generate_tests(metafunc):
if "post_training_stack" in metafunc.fixturenames:
available_fixtures = {
"eval": POST_TRAINING_FIXTURES,
"datasetio": DATASETIO_FIXTURES,
}
combinations = (
get_provider_fixture_overrides(metafunc.config, available_fixtures) or DEFAULT_PROVIDER_COMBINATIONS
)
metafunc.parametrize("post_training_stack", combinations, indirect=True)


@ -1,72 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
import pytest_asyncio
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import StringType
from llama_stack.apis.datasets import DatasetInput
from llama_stack.apis.models import ModelInput
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture
@pytest.fixture(scope="session")
def post_training_torchtune() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="torchtune",
provider_type="inline::torchtune",
config={},
)
],
)
POST_TRAINING_FIXTURES = ["torchtune"]
@pytest_asyncio.fixture(scope="session")
async def post_training_stack(request):
fixture_dict = request.param
providers = {}
provider_data = {}
for key in ["post_training", "datasetio"]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers
if fixture.provider_data:
provider_data.update(fixture.provider_data)
test_stack = await construct_stack_for_test(
[Api.post_training, Api.datasetio],
providers,
provider_data,
models=[ModelInput(model_id="meta-llama/Llama-3.2-3B-Instruct")],
datasets=[
DatasetInput(
dataset_id="alpaca",
provider_id="huggingface",
url=URL(uri="https://huggingface.co/datasets/tatsu-lab/alpaca"),
metadata={
"path": "tatsu-lab/alpaca",
"split": "train",
},
dataset_schema={
"instruction": StringType(),
"input": StringType(),
"output": StringType(),
"text": StringType(),
},
),
],
)
return test_stack.impls[Api.post_training]


@ -1,75 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from ..conftest import get_provider_fixture_overrides
from ..datasetio.fixtures import DATASETIO_FIXTURES
from ..inference.fixtures import INFERENCE_FIXTURES
from .fixtures import SCORING_FIXTURES
DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
{
"scoring": "basic",
"datasetio": "localfs",
"inference": "together",
},
id="basic_scoring_together_inference",
marks=pytest.mark.basic_scoring_together_inference,
),
pytest.param(
{
"scoring": "braintrust",
"datasetio": "localfs",
"inference": "together",
},
id="braintrust_scoring_together_inference",
marks=pytest.mark.braintrust_scoring_together_inference,
),
pytest.param(
{
"scoring": "llm_as_judge",
"datasetio": "localfs",
"inference": "together",
},
id="llm_as_judge_scoring_together_inference",
marks=pytest.mark.llm_as_judge_scoring_together_inference,
),
]
def pytest_configure(config):
for fixture_name in [
"basic_scoring_together_inference",
"braintrust_scoring_together_inference",
"llm_as_judge_scoring_together_inference",
]:
config.addinivalue_line(
"markers",
f"{fixture_name}: marks tests as {fixture_name} specific",
)
def pytest_generate_tests(metafunc):
judge_model = metafunc.config.getoption("--judge-model")
if "judge_model" in metafunc.fixturenames:
metafunc.parametrize(
"judge_model",
[pytest.param(judge_model, id="")],
indirect=True,
)
if "scoring_stack" in metafunc.fixturenames:
available_fixtures = {
"scoring": SCORING_FIXTURES,
"datasetio": DATASETIO_FIXTURES,
"inference": INFERENCE_FIXTURES,
}
combinations = (
get_provider_fixture_overrides(metafunc.config, available_fixtures) or DEFAULT_PROVIDER_COMBINATIONS
)
metafunc.parametrize("scoring_stack", combinations, indirect=True)


@ -1,100 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
import pytest_asyncio
from llama_stack.apis.models import ModelInput
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.scoring.braintrust import BraintrustScoringConfig
from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture, remote_stack_fixture
from ..env import get_env_or_fail
@pytest.fixture(scope="session")
def scoring_remote() -> ProviderFixture:
return remote_stack_fixture()
@pytest.fixture(scope="session")
def judge_model(request):
if hasattr(request, "param"):
return request.param
return request.config.getoption("--judge-model", None)
@pytest.fixture(scope="session")
def scoring_basic() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="basic",
provider_type="inline::basic",
config={},
)
],
)
@pytest.fixture(scope="session")
def scoring_braintrust() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="braintrust",
provider_type="inline::braintrust",
config=BraintrustScoringConfig(
openai_api_key=get_env_or_fail("OPENAI_API_KEY"),
).model_dump(),
)
],
)
@pytest.fixture(scope="session")
def scoring_llm_as_judge() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="llm-as-judge",
provider_type="inline::llm-as-judge",
config={},
)
],
)
SCORING_FIXTURES = ["basic", "remote", "braintrust", "llm_as_judge"]
@pytest_asyncio.fixture(scope="session")
async def scoring_stack(request, inference_model, judge_model):
fixture_dict = request.param
providers = {}
provider_data = {}
for key in ["datasetio", "scoring", "inference"]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers
if fixture.provider_data:
provider_data.update(fixture.provider_data)
test_stack = await construct_stack_for_test(
[Api.scoring, Api.datasetio, Api.inference],
providers,
provider_data,
models=[
ModelInput(model_id=model)
for model in [
inference_model,
judge_model,
]
],
)
return test_stack.impls


@ -1,213 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
LLMAsJudgeScoringFnParams,
RegexParserScoringFnParams,
)
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
# How to run this test:
#
# pytest llama_stack/providers/tests/scoring/test_scoring.py
# -m "meta_reference"
# -v -s --tb=short --disable-warnings
@pytest.fixture
def sample_judge_prompt_template():
return "Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9."
class TestScoring:
@pytest.mark.asyncio
async def test_scoring_functions_list(self, scoring_stack):
# NOTE: this needs you to ensure that you are starting from a clean state
# but so far we don't have an unregister API unfortunately, so be careful
scoring_functions_impl = scoring_stack[Api.scoring_functions]
response = await scoring_functions_impl.list_scoring_functions()
assert isinstance(response, list)
assert len(response) > 0
@pytest.mark.asyncio
async def test_scoring_score(self, scoring_stack):
(
scoring_impl,
scoring_functions_impl,
datasetio_impl,
datasets_impl,
) = (
scoring_stack[Api.scoring],
scoring_stack[Api.scoring_functions],
scoring_stack[Api.datasetio],
scoring_stack[Api.datasets],
)
scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
provider_id = scoring_fns_list[0].provider_id
if provider_id == "llm-as-judge":
pytest.skip(f"{provider_id} provider does not support scoring without params")
await register_dataset(datasets_impl, for_rag=True)
response = await datasets_impl.list_datasets()
assert len(response) == 1
# scoring individual rows
rows = await datasetio_impl.get_rows_paginated(
dataset_id="test_dataset",
rows_in_page=3,
)
assert len(rows.rows) == 3
scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
scoring_functions = {
scoring_fns_list[0].identifier: None,
}
response = await scoring_impl.score(
input_rows=rows.rows,
scoring_functions=scoring_functions,
)
assert len(response.results) == len(scoring_functions)
for x in scoring_functions:
assert x in response.results
assert len(response.results[x].score_rows) == len(rows.rows)
# score batch
response = await scoring_impl.score_batch(
dataset_id="test_dataset",
scoring_functions=scoring_functions,
)
assert len(response.results) == len(scoring_functions)
for x in scoring_functions:
assert x in response.results
assert len(response.results[x].score_rows) == 5
@pytest.mark.asyncio
async def test_scoring_score_with_params_llm_as_judge(
self, scoring_stack, sample_judge_prompt_template, judge_model
):
(
scoring_impl,
scoring_functions_impl,
datasetio_impl,
datasets_impl,
) = (
scoring_stack[Api.scoring],
scoring_stack[Api.scoring_functions],
scoring_stack[Api.datasetio],
scoring_stack[Api.datasets],
)
await register_dataset(datasets_impl, for_rag=True)
response = await datasets_impl.list_datasets()
assert len(response) == 1
scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
provider_id = scoring_fns_list[0].provider_id
if provider_id == "braintrust" or provider_id == "basic":
pytest.skip(f"{provider_id} provider does not support scoring with params")
# scoring individual rows
rows = await datasetio_impl.get_rows_paginated(
dataset_id="test_dataset",
rows_in_page=3,
)
assert len(rows.rows) == 3
scoring_functions = {
"llm-as-judge::base": LLMAsJudgeScoringFnParams(
judge_model=judge_model,
prompt_template=sample_judge_prompt_template,
judge_score_regexes=[r"Score: (\d+)"],
aggregation_functions=[AggregationFunctionType.categorical_count],
)
}
response = await scoring_impl.score(
input_rows=rows.rows,
scoring_functions=scoring_functions,
)
assert len(response.results) == len(scoring_functions)
for x in scoring_functions:
assert x in response.results
assert len(response.results[x].score_rows) == len(rows.rows)
# score batch
response = await scoring_impl.score_batch(
dataset_id="test_dataset",
scoring_functions=scoring_functions,
)
assert len(response.results) == len(scoring_functions)
for x in scoring_functions:
assert x in response.results
assert len(response.results[x].score_rows) == 5
@pytest.mark.asyncio
async def test_scoring_score_with_aggregation_functions(
self, scoring_stack, sample_judge_prompt_template, judge_model
):
(
scoring_impl,
scoring_functions_impl,
datasetio_impl,
datasets_impl,
) = (
scoring_stack[Api.scoring],
scoring_stack[Api.scoring_functions],
scoring_stack[Api.datasetio],
scoring_stack[Api.datasets],
)
await register_dataset(datasets_impl, for_rag=True)
rows = await datasetio_impl.get_rows_paginated(
dataset_id="test_dataset",
rows_in_page=3,
)
assert len(rows.rows) == 3
scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
scoring_functions = {}
aggr_fns = [
AggregationFunctionType.accuracy,
AggregationFunctionType.median,
AggregationFunctionType.categorical_count,
AggregationFunctionType.average,
]
for x in scoring_fns_list:
if x.provider_id == "llm-as-judge":
aggr_fns = [AggregationFunctionType.categorical_count]
scoring_functions[x.identifier] = LLMAsJudgeScoringFnParams(
judge_model=judge_model,
prompt_template=sample_judge_prompt_template,
judge_score_regexes=[r"Score: (\d+)"],
aggregation_functions=aggr_fns,
)
elif x.provider_id == "basic" or x.provider_id == "braintrust":
if "regex_parser" in x.identifier:
scoring_functions[x.identifier] = RegexParserScoringFnParams(
aggregation_functions=aggr_fns,
)
else:
scoring_functions[x.identifier] = BasicScoringFnParams(
aggregation_functions=aggr_fns,
)
else:
scoring_functions[x.identifier] = None
response = await scoring_impl.score(
input_rows=rows.rows,
scoring_functions=scoring_functions,
)
assert len(response.results) == len(scoring_functions)
for x in scoring_functions:
assert x in response.results
assert len(response.results[x].score_rows) == len(rows.rows)
assert len(response.results[x].aggregated_results) == len(aggr_fns)


@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


@ -1,48 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from ..conftest import get_provider_fixture_overrides
from ..inference.fixtures import INFERENCE_FIXTURES
from ..safety.fixtures import SAFETY_FIXTURES
from ..vector_io.fixtures import VECTOR_IO_FIXTURES
from .fixtures import TOOL_RUNTIME_FIXTURES
DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
{
"inference": "together",
"safety": "llama_guard",
"vector_io": "faiss",
"tool_runtime": "memory_and_search",
},
id="together",
marks=pytest.mark.together,
),
]
def pytest_configure(config):
for mark in ["together"]:
config.addinivalue_line(
"markers",
f"{mark}: marks tests as {mark} specific",
)
def pytest_generate_tests(metafunc):
if "tools_stack" in metafunc.fixturenames:
available_fixtures = {
"inference": INFERENCE_FIXTURES,
"safety": SAFETY_FIXTURES,
"vector_io": VECTOR_IO_FIXTURES,
"tool_runtime": TOOL_RUNTIME_FIXTURES,
}
combinations = (
get_provider_fixture_overrides(metafunc.config, available_fixtures) or DEFAULT_PROVIDER_COMBINATIONS
)
metafunc.parametrize("tools_stack", combinations, indirect=True)


@ -1,133 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import pytest
import pytest_asyncio
from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.apis.tools import ToolGroupInput
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture
@pytest.fixture(scope="session")
def tool_runtime_memory_and_search() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="rag-runtime",
provider_type="inline::rag-runtime",
config={},
),
Provider(
provider_id="tavily-search",
provider_type="remote::tavily-search",
config={
"api_key": os.environ["TAVILY_SEARCH_API_KEY"],
},
),
Provider(
provider_id="wolfram-alpha",
provider_type="remote::wolfram-alpha",
config={
"api_key": os.environ["WOLFRAM_ALPHA_API_KEY"],
},
),
],
)
@pytest.fixture(scope="session")
def tool_group_input_memory() -> ToolGroupInput:
return ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
)
@pytest.fixture(scope="session")
def tool_group_input_tavily_search() -> ToolGroupInput:
return ToolGroupInput(
toolgroup_id="builtin::web_search",
provider_id="tavily-search",
)
@pytest.fixture(scope="session")
def tool_group_input_wolfram_alpha() -> ToolGroupInput:
return ToolGroupInput(
toolgroup_id="builtin::wolfram_alpha",
provider_id="wolfram-alpha",
)
TOOL_RUNTIME_FIXTURES = ["memory_and_search"]
@pytest_asyncio.fixture(scope="session")
async def tools_stack(
request,
inference_model,
tool_group_input_memory,
tool_group_input_tavily_search,
tool_group_input_wolfram_alpha,
):
fixture_dict = request.param
providers = {}
provider_data = {}
for key in ["inference", "vector_io", "tool_runtime"]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers
if key == "inference":
providers[key].append(
Provider(
provider_id="tools_memory_provider",
provider_type="inline::sentence-transformers",
config={},
)
)
if fixture.provider_data:
provider_data.update(fixture.provider_data)
inference_models = inference_model if isinstance(inference_model, list) else [inference_model]
models = [
ModelInput(
model_id=model,
model_type=ModelType.llm,
provider_id=providers["inference"][0].provider_id,
)
for model in inference_models
]
models.append(
ModelInput(
model_id="all-MiniLM-L6-v2",
model_type=ModelType.embedding,
provider_id="tools_memory_provider",
metadata={"embedding_dimension": 384},
)
)
test_stack = await construct_stack_for_test(
[
Api.tool_groups,
Api.inference,
Api.vector_io,
Api.tool_runtime,
],
providers,
provider_data,
models=models,
tool_groups=[
tool_group_input_tavily_search,
tool_group_input_wolfram_alpha,
tool_group_input_memory,
],
)
return test_stack


@ -1,109 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import pytest
from llama_stack.apis.tools import RAGDocument, RAGQueryResult, ToolInvocationResult
from llama_stack.providers.datatypes import Api
@pytest.fixture
def sample_search_query():
return "What are the latest developments in quantum computing?"
@pytest.fixture
def sample_wolfram_alpha_query():
return "What is the square root of 16?"
@pytest.fixture
def sample_documents():
urls = [
"memory_optimizations.rst",
"chat.rst",
"llama3.rst",
"qat_finetune.rst",
"lora_finetune.rst",
]
return [
RAGDocument(
document_id=f"num-{i}",
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain",
metadata={},
)
for i, url in enumerate(urls)
]
class TestTools:
@pytest.mark.asyncio
async def test_web_search_tool(self, tools_stack, sample_search_query):
"""Test the web search tool functionality."""
if "TAVILY_SEARCH_API_KEY" not in os.environ:
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
tools_impl = tools_stack.impls[Api.tool_runtime]
# Execute the tool
response = await tools_impl.invoke_tool(tool_name="web_search", kwargs={"query": sample_search_query})
# Verify the response
assert isinstance(response, ToolInvocationResult)
assert response.content is not None
assert len(response.content) > 0
assert isinstance(response.content, str)
@pytest.mark.asyncio
async def test_wolfram_alpha_tool(self, tools_stack, sample_wolfram_alpha_query):
"""Test the wolfram alpha tool functionality."""
if "WOLFRAM_ALPHA_API_KEY" not in os.environ:
pytest.skip("WOLFRAM_ALPHA_API_KEY not set, skipping test")
tools_impl = tools_stack.impls[Api.tool_runtime]
response = await tools_impl.invoke_tool(tool_name="wolfram_alpha", kwargs={"query": sample_wolfram_alpha_query})
# Verify the response
assert isinstance(response, ToolInvocationResult)
assert response.content is not None
assert len(response.content) > 0
assert isinstance(response.content, str)
@pytest.mark.asyncio
async def test_rag_tool(self, tools_stack, sample_documents):
"""Test the memory tool functionality."""
vector_dbs_impl = tools_stack.impls[Api.vector_dbs]
tools_impl = tools_stack.impls[Api.tool_runtime]
# Register memory bank
await vector_dbs_impl.register_vector_db(
vector_db_id="test_bank",
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
provider_id="faiss",
)
# Insert documents into memory
await tools_impl.rag_tool.insert(
documents=sample_documents,
vector_db_id="test_bank",
chunk_size_in_tokens=512,
)
# Execute the memory tool
response = await tools_impl.rag_tool.query(
content="What are the main topics covered in the documentation?",
vector_db_ids=["test_bank"],
)
# Verify the response
assert isinstance(response, RAGQueryResult)
assert response.content is not None
assert len(response.content) > 0


@@ -27,6 +27,7 @@ distribution_spec:
tool_runtime:
- remote::brave-search
- remote::tavily-search
+- remote::wolfram-alpha
- inline::code-interpreter
- inline::rag-runtime
- remote::model-context-protocol


@@ -35,6 +35,7 @@ def get_distribution_template() -> DistributionTemplate:
"tool_runtime": [
"remote::brave-search",
"remote::tavily-search",
+"remote::wolfram-alpha",
"inline::code-interpreter",
"inline::rag-runtime",
"remote::model-context-protocol",
@@ -77,6 +78,10 @@ def get_distribution_template() -> DistributionTemplate:
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
+ToolGroupInput(
+toolgroup_id="builtin::wolfram_alpha",
+provider_id="wolfram-alpha",
+),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",


@@ -86,6 +86,9 @@ providers:
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
max_results: 3
+- provider_id: wolfram-alpha
+provider_type: remote::wolfram-alpha
+config: {}
- provider_id: code-interpreter
provider_type: inline::code-interpreter
config: {}
@@ -225,6 +228,8 @@ benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
+- toolgroup_id: builtin::wolfram_alpha
+provider_id: wolfram-alpha
- toolgroup_id: builtin::rag
provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter


@@ -80,6 +80,9 @@ providers:
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
max_results: 3
+- provider_id: wolfram-alpha
+provider_type: remote::wolfram-alpha
+config: {}
- provider_id: code-interpreter
provider_type: inline::code-interpreter
config: {}
@@ -214,6 +217,8 @@ benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
+- toolgroup_id: builtin::wolfram_alpha
+provider_id: wolfram-alpha
- toolgroup_id: builtin::rag
provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter


@@ -29,4 +29,5 @@ distribution_spec:
- inline::code-interpreter
- inline::rag-runtime
- remote::model-context-protocol
+- remote::wolfram-alpha
image_type: conda


@@ -34,6 +34,7 @@ def get_distribution_template() -> DistributionTemplate:
"inline::code-interpreter",
"inline::rag-runtime",
"remote::model-context-protocol",
+"remote::wolfram-alpha",
],
}
name = "ollama"
@@ -78,6 +79,10 @@ def get_distribution_template() -> DistributionTemplate:
toolgroup_id="builtin::code_interpreter",
provider_id="code-interpreter",
),
+ToolGroupInput(
+toolgroup_id="builtin::wolfram_alpha",
+provider_id="wolfram-alpha",
+),
]
return DistributionTemplate(


@@ -85,6 +85,9 @@ providers:
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
+- provider_id: wolfram-alpha
+provider_type: remote::wolfram-alpha
+config: {}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db
@@ -119,5 +122,7 @@ tool_groups:
provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter
+- toolgroup_id: builtin::wolfram_alpha
+provider_id: wolfram-alpha
server:
port: 8321


@@ -82,6 +82,9 @@ providers:
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
+- provider_id: wolfram-alpha
+provider_type: remote::wolfram-alpha
+config: {}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db
@@ -108,5 +111,7 @@ tool_groups:
provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter
+- toolgroup_id: builtin::wolfram_alpha
+provider_id: wolfram-alpha
server:
port: 8321


@@ -30,4 +30,5 @@ distribution_spec:
- inline::code-interpreter
- inline::rag-runtime
- remote::model-context-protocol
+- remote::wolfram-alpha
image_type: conda


@@ -96,6 +96,9 @@ providers:
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
+- provider_id: wolfram-alpha
+provider_type: remote::wolfram-alpha
+config: {}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db
@@ -126,5 +129,7 @@ tool_groups:
provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter
+- toolgroup_id: builtin::wolfram_alpha
+provider_id: wolfram-alpha
server:
port: 8321


@@ -90,6 +90,9 @@ providers:
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
+- provider_id: wolfram-alpha
+provider_type: remote::wolfram-alpha
+config: {}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db
@@ -115,5 +118,7 @@ tool_groups:
provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter
+- toolgroup_id: builtin::wolfram_alpha
+provider_id: wolfram-alpha
server:
port: 8321


@@ -37,6 +37,7 @@ def get_distribution_template() -> DistributionTemplate:
"inline::code-interpreter",
"inline::rag-runtime",
"remote::model-context-protocol",
+"remote::wolfram-alpha",
],
}
name = "remote-vllm"
@@ -87,6 +88,10 @@ def get_distribution_template() -> DistributionTemplate:
toolgroup_id="builtin::code_interpreter",
provider_id="code-interpreter",
),
+ToolGroupInput(
+toolgroup_id="builtin::wolfram_alpha",
+provider_id="wolfram-alpha",
+),
]
return DistributionTemplate(


@@ -30,4 +30,5 @@ distribution_spec:
- inline::code-interpreter
- inline::rag-runtime
- remote::model-context-protocol
+- remote::wolfram-alpha
image_type: conda


@@ -95,6 +95,9 @@ providers:
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
+- provider_id: wolfram-alpha
+provider_type: remote::wolfram-alpha
+config: {}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/registry.db
@@ -226,5 +229,7 @@ tool_groups:
provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter
+- toolgroup_id: builtin::wolfram_alpha
+provider_id: wolfram-alpha
server:
port: 8321


@@ -89,6 +89,9 @@ providers:
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
+- provider_id: wolfram-alpha
+provider_type: remote::wolfram-alpha
+config: {}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/registry.db
@@ -215,5 +218,7 @@ tool_groups:
provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter
+- toolgroup_id: builtin::wolfram_alpha
+provider_id: wolfram-alpha
server:
port: 8321


@@ -38,6 +38,7 @@ def get_distribution_template() -> DistributionTemplate:
"inline::code-interpreter",
"inline::rag-runtime",
"remote::model-context-protocol",
+"remote::wolfram-alpha",
],
}
name = "together"
@@ -73,6 +74,10 @@ def get_distribution_template() -> DistributionTemplate:
toolgroup_id="builtin::code_interpreter",
provider_id="code-interpreter",
),
+ToolGroupInput(
+toolgroup_id="builtin::wolfram_alpha",
+provider_id="wolfram-alpha",
+),
]
embedding_model = ModelInput(
model_id="all-MiniLM-L6-v2",


@@ -20,7 +20,7 @@ from llama_stack.distribution.datatypes import Provider, StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.stack import replace_env_vars
from llama_stack.distribution.utils.dynamic import instantiate_class_type
-from llama_stack.providers.tests.env import get_env_or_fail
+from llama_stack.env import get_env_or_fail
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from .fixtures.recordable_mock import RecordableMock
@@ -84,6 +84,11 @@ def pytest_addoption(parser):
default=None,
help="Specify the embedding model to use for testing",
)
+parser.addoption(
+"--judge-model",
+default=None,
+help="Specify the judge model to use for testing",
+)
parser.addoption(
"--embedding-dimension",
type=int,
@@ -109,6 +114,7 @@ def provider_data():
"TOGETHER_API_KEY": "together_api_key",
"ANTHROPIC_API_KEY": "anthropic_api_key",
"GROQ_API_KEY": "groq_api_key",
+"WOLFRAM_ALPHA_API_KEY": "wolfram_alpha_api_key",
}
provider_data = {}
for key, value in keymap.items():
@@ -260,7 +266,9 @@ def inference_provider_type(llama_stack_client):
@pytest.fixture(scope="session")
-def client_with_models(llama_stack_client, text_model_id, vision_model_id, embedding_model_id, embedding_dimension):
+def client_with_models(
+llama_stack_client, text_model_id, vision_model_id, embedding_model_id, embedding_dimension, judge_model_id
+):
client = llama_stack_client
providers = [p for p in client.providers.list() if p.api == "inference"]
@@ -274,6 +282,8 @@ def client_with_models(llama_stack_client, text_model_id, vision_model_id, embed
client.models.register(model_id=text_model_id, provider_id=inference_providers[0])
if vision_model_id and vision_model_id not in model_ids:
client.models.register(model_id=vision_model_id, provider_id=inference_providers[0])
+if judge_model_id and judge_model_id not in model_ids:
+client.models.register(model_id=judge_model_id, provider_id=inference_providers[0])
if embedding_model_id and embedding_dimension and embedding_model_id not in model_ids:
# try to find a provider that supports embeddings, if sentence-transformers is not available
@@ -328,6 +338,14 @@ def pytest_generate_tests(metafunc):
if val is not None:
id_parts.append(f"emb={get_short_id(val)}")
+if "judge_model_id" in metafunc.fixturenames:
+params.append("judge_model_id")
+val = metafunc.config.getoption("--judge-model")
+print(f"judge_model_id: {val}")
+values.append(val)
+if val is not None:
+id_parts.append(f"judge={get_short_id(val)}")
if "embedding_dimension" in metafunc.fixturenames:
params.append("embedding_dimension")
val = metafunc.config.getoption("--embedding-dimension")


@ -0,0 +1,118 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import mimetypes
import os
from pathlib import Path
import pytest
# How to run this test:
#
# pytest llama_stack/providers/tests/datasetio/test_datasetio.py
# -m "meta_reference"
# -v -s --tb=short --disable-warnings
def data_url_from_file(file_path: str) -> str:
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
with open(file_path, "rb") as file:
file_content = file.read()
base64_content = base64.b64encode(file_content).decode("utf-8")
mime_type, _ = mimetypes.guess_type(file_path)
data_url = f"data:{mime_type};base64,{base64_content}"
return data_url
def register_dataset(llama_stack_client, for_generation=False, for_rag=False, dataset_id="test_dataset"):
if for_rag:
test_file = Path(os.path.abspath(__file__)).parent / "test_rag_dataset.csv"
else:
test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
test_url = data_url_from_file(str(test_file))
if for_generation:
dataset_schema = {
"expected_answer": {"type": "string"},
"input_query": {"type": "string"},
"chat_completion_input": {"type": "chat_completion_input"},
}
elif for_rag:
dataset_schema = {
"expected_answer": {"type": "string"},
"input_query": {"type": "string"},
"generated_answer": {"type": "string"},
"context": {"type": "string"},
}
else:
dataset_schema = {
"expected_answer": {"type": "string"},
"input_query": {"type": "string"},
"generated_answer": {"type": "string"},
}
llama_stack_client.datasets.register(
dataset_id=dataset_id,
dataset_schema=dataset_schema,
url=dict(uri=test_url),
provider_id="localfs",
)
def test_datasets_list(llama_stack_client):
# NOTE: this needs you to ensure that you are starting from a clean state
# but so far we don't have an unregister API unfortunately, so be careful
response = llama_stack_client.datasets.list()
assert isinstance(response, list)
assert len(response) == 0
def test_register_dataset(llama_stack_client):
register_dataset(llama_stack_client)
response = llama_stack_client.datasets.list()
assert isinstance(response, list)
assert len(response) == 1
assert response[0].identifier == "test_dataset"
with pytest.raises(ValueError):
# unregister a dataset that does not exist
llama_stack_client.datasets.unregister("test_dataset2")
llama_stack_client.datasets.unregister("test_dataset")
response = llama_stack_client.datasets.list()
assert isinstance(response, list)
assert len(response) == 0
with pytest.raises(ValueError):
llama_stack_client.datasets.unregister("test_dataset")
def test_get_rows_paginated(llama_stack_client):
register_dataset(llama_stack_client)
response = llama_stack_client.datasetio.get_rows_paginated(
dataset_id="test_dataset",
rows_in_page=3,
)
assert isinstance(response.rows, list)
assert len(response.rows) == 3
assert response.next_page_token == "3"
# iterate over all rows
response = llama_stack_client.datasetio.get_rows_paginated(
dataset_id="test_dataset",
rows_in_page=2,
page_token=response.next_page_token,
)
assert isinstance(response.rows, list)
assert len(response.rows) == 2
assert response.next_page_token == "5"


@@ -10,15 +10,13 @@ import pytest
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType
from llama_stack.apis.eval.eval import (
-AppBenchmarkConfig,
-BenchmarkBenchmarkConfig,
ModelCandidate,
)
from llama_stack.apis.inference import SamplingParams
from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams
from llama_stack.distribution.datatypes import Api
-from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
+from ..datasetio.test_datasetio import register_dataset
from .constants import JUDGE_PROMPT
# How to run this test:
@@ -28,6 +26,7 @@ from .constants import JUDGE_PROMPT
# -v -s --tb=short --disable-warnings
+@pytest.mark.skip(reason="FIXME FIXME @yanxi0830 this needs to be migrated to use the API")
class Testeval:
@pytest.mark.asyncio
async def test_benchmarks_list(self, eval_stack):
@@ -68,7 +67,7 @@ class Testeval:
benchmark_id=benchmark_id,
input_rows=rows.rows,
scoring_functions=scoring_functions,
-benchmark_config=AppBenchmarkConfig(
+benchmark_config=dict(
eval_candidate=ModelCandidate(
model=inference_model,
sampling_params=SamplingParams(),
@@ -111,7 +110,7 @@ class Testeval:
)
response = await eval_impl.run_eval(
benchmark_id=benchmark_id,
-benchmark_config=AppBenchmarkConfig(
+benchmark_config=dict(
eval_candidate=ModelCandidate(
model=inference_model,
sampling_params=SamplingParams(),
@@ -169,7 +168,7 @@ class Testeval:
benchmark_id = "meta-reference-mmlu"
response = await eval_impl.run_eval(
benchmark_id=benchmark_id,
-benchmark_config=BenchmarkBenchmarkConfig(
+benchmark_config=dict(
eval_candidate=ModelCandidate(
model=inference_model,
sampling_params=SamplingParams(),

View file

@@ -26,6 +26,7 @@ from llama_stack.apis.post_training import (
 # -v -s --tb=short --disable-warnings
+@pytest.mark.skip(reason="FIXME FIXME @yanxi0830 this needs to be migrated to use the API")
 class TestPostTraining:
     @pytest.mark.asyncio
     async def test_supervised_fine_tune(self, post_training_stack):

View file

@@ -16,6 +16,7 @@ import pytest
 from pytest import CollectReport
 from termcolor import cprint
+from llama_stack.env import get_env_or_fail
 from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.models.llama.sku_list import (
     all_registered_models,
@@ -26,7 +27,6 @@ from llama_stack.models.llama.sku_list import (
     safety_models,
 )
 from llama_stack.providers.datatypes import Api
-from llama_stack.providers.tests.env import get_env_or_fail
 from .metadata import API_MAPS

View file

@@ -0,0 +1,160 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from ..datasetio.test_datasetio import register_dataset
@pytest.fixture
def sample_judge_prompt_template():
return "Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9."
def test_scoring_functions_list(llama_stack_client):
# NOTE: this needs you to ensure that you are starting from a clean state
# but so far we don't have an unregister API unfortunately, so be careful
response = llama_stack_client.scoring_functions.list()
assert isinstance(response, list)
assert len(response) > 0
def test_scoring_score(llama_stack_client):
register_dataset(llama_stack_client, for_rag=True)
response = llama_stack_client.datasets.list()
assert len(response) == 1
# scoring individual rows
rows = llama_stack_client.datasetio.get_rows_paginated(
dataset_id="test_dataset",
rows_in_page=3,
)
assert len(rows.rows) == 3
scoring_fns_list = llama_stack_client.scoring_functions.list()
scoring_functions = {
scoring_fns_list[0].identifier: None,
}
response = llama_stack_client.scoring.score(
input_rows=rows.rows,
scoring_functions=scoring_functions,
)
assert len(response.results) == len(scoring_functions)
for x in scoring_functions:
assert x in response.results
assert len(response.results[x].score_rows) == len(rows.rows)
# score batch
response = llama_stack_client.scoring.score_batch(
dataset_id="test_dataset",
scoring_functions=scoring_functions,
save_results_dataset=False,
)
assert len(response.results) == len(scoring_functions)
for x in scoring_functions:
assert x in response.results
assert len(response.results[x].score_rows) == 5
def test_scoring_score_with_params_llm_as_judge(llama_stack_client, sample_judge_prompt_template, judge_model_id):
register_dataset(llama_stack_client, for_rag=True)
response = llama_stack_client.datasets.list()
assert len(response) == 1
# scoring individual rows
rows = llama_stack_client.datasetio.get_rows_paginated(
dataset_id="test_dataset",
rows_in_page=3,
)
assert len(rows.rows) == 3
scoring_functions = {
"llm-as-judge::base": dict(
type="llm_as_judge",
judge_model=judge_model_id,
prompt_template=sample_judge_prompt_template,
judge_score_regexes=[r"Score: (\d+)"],
aggregation_functions=[
"categorical_count",
],
)
}
response = llama_stack_client.scoring.score(
input_rows=rows.rows,
scoring_functions=scoring_functions,
)
assert len(response.results) == len(scoring_functions)
for x in scoring_functions:
assert x in response.results
assert len(response.results[x].score_rows) == len(rows.rows)
# score batch
response = llama_stack_client.scoring.score_batch(
dataset_id="test_dataset",
scoring_functions=scoring_functions,
save_results_dataset=False,
)
assert len(response.results) == len(scoring_functions)
for x in scoring_functions:
assert x in response.results
assert len(response.results[x].score_rows) == 5
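As a quick illustration of the judge_score_regexes used above, the pattern simply pulls the numeric score out of a judge reply that follows the prompt template; the sample reply below is made up:

import re

sample_judge_reply = "Score: 7"  # hypothetical judge output following the template
match = re.search(r"Score: (\d+)", sample_judge_reply)
assert match is not None and int(match.group(1)) == 7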
@pytest.mark.skip(reason="Skipping because this seems to be really slow")
def test_scoring_score_with_aggregation_functions(llama_stack_client, sample_judge_prompt_template, judge_model_id):
register_dataset(llama_stack_client, for_rag=True)
rows = llama_stack_client.datasetio.get_rows_paginated(
dataset_id="test_dataset",
rows_in_page=3,
)
assert len(rows.rows) == 3
scoring_fns_list = llama_stack_client.scoring_functions.list()
scoring_functions = {}
aggr_fns = [
"accuracy",
"median",
"categorical_count",
"average",
]
for x in scoring_fns_list:
if x.provider_id == "llm-as-judge":
aggr_fns = ["categorical_count"]
scoring_functions[x.identifier] = dict(
type="llm_as_judge",
judge_model=judge_model_id,
prompt_template=sample_judge_prompt_template,
judge_score_regexes=[r"Score: (\d+)"],
aggregation_functions=aggr_fns,
)
elif x.provider_id == "basic" or x.provider_id == "braintrust":
if "regex_parser" in x.identifier:
scoring_functions[x.identifier] = dict(
type="regex_parser",
parsing_regexes=[r"Score: (\d+)"],
aggregation_functions=aggr_fns,
)
else:
scoring_functions[x.identifier] = dict(
type="basic",
aggregation_functions=aggr_fns,
)
else:
scoring_functions[x.identifier] = None
response = llama_stack_client.scoring.score(
input_rows=rows.rows,
scoring_functions=scoring_functions,
)
assert len(response.results) == len(scoring_functions)
for x in scoring_functions:
assert x in response.results
assert len(response.results[x].score_rows) == len(rows.rows)
assert len(response.results[x].aggregated_results) == len(aggr_fns)

View file

@@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import os
import pytest
@pytest.fixture
def sample_search_query():
return "What are the latest developments in quantum computing?"
@pytest.fixture
def sample_wolfram_alpha_query():
return "What is the square root of 16?"
def test_web_search_tool(llama_stack_client, sample_search_query):
"""Test the web search tool functionality."""
if "TAVILY_SEARCH_API_KEY" not in os.environ:
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
response = llama_stack_client.tool_runtime.invoke_tool(
tool_name="web_search", kwargs={"query": sample_search_query}
)
# Verify the response
assert response.content is not None
assert len(response.content) > 0
assert isinstance(response.content, str)
content = json.loads(response.content)
assert "query" in content
assert "top_k" in content
assert len(content["top_k"]) > 0
first = content["top_k"][0]
assert "title" in first
assert "url" in first
def test_wolfram_alpha_tool(llama_stack_client, sample_wolfram_alpha_query):
"""Test the wolfram alpha tool functionality."""
if "WOLFRAM_ALPHA_API_KEY" not in os.environ:
pytest.skip("WOLFRAM_ALPHA_API_KEY not set, skipping test")
response = llama_stack_client.tool_runtime.invoke_tool(
tool_name="wolfram_alpha", kwargs={"query": sample_wolfram_alpha_query}
)
print(response.content)
assert response.content is not None
assert len(response.content) > 0
assert isinstance(response.content, str)
content = json.loads(response.content)
result = content["queryresult"]
assert "success" in result
assert result["success"]
assert "pods" in result
assert len(result["pods"]) > 0
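For context, the assertions above depend only on this much of the Wolfram Alpha response shape; the values below are placeholders, not real API output:

example_content = {
    "queryresult": {
        "success": True,  # the query was understood and answered
        "pods": [{}],     # at least one result pod is expected
    }
}
assert example_content["queryresult"]["success"]
assert len(example_content["queryresult"]["pods"]) > 0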

View file

@@ -4,29 +4,23 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import random
 import pytest
 from llama_stack_client.types import Document
 @pytest.fixture(scope="function")
-def empty_vector_db_registry(llama_stack_client):
-    vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
-    for vector_db_id in vector_dbs:
-        llama_stack_client.vector_dbs.unregister(vector_db_id=vector_db_id)
-@pytest.fixture(scope="function")
-def single_entry_vector_db_registry(llama_stack_client, empty_vector_db_registry):
-    vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}"
-    llama_stack_client.vector_dbs.register(
-        vector_db_id=vector_db_id,
-        embedding_model="all-MiniLM-L6-v2",
-        embedding_dimension=384,
-    )
-    vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
-    return vector_dbs
+def client_with_empty_registry(client_with_models):
+    def clear_registry():
+        vector_dbs = [vector_db.identifier for vector_db in client_with_models.vector_dbs.list()]
+        for vector_db_id in vector_dbs:
+            client_with_models.vector_dbs.unregister(vector_db_id=vector_db_id)
+    clear_registry()
+    yield client_with_models
+    # you must clean after the last test if you were running tests against
+    # a stateful server instance
+    clear_registry()
 @pytest.fixture(scope="session")
@@ -63,9 +57,15 @@ def assert_valid_response(response):
         assert isinstance(chunk.content, str)
-def test_vector_db_insert_inline_and_query(llama_stack_client, single_entry_vector_db_registry, sample_documents):
-    vector_db_id = single_entry_vector_db_registry[0]
-    llama_stack_client.tool_runtime.rag_tool.insert(
+def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_documents, embedding_model_id):
+    vector_db_id = "test_vector_db"
+    client_with_empty_registry.vector_dbs.register(
+        vector_db_id=vector_db_id,
+        embedding_model=embedding_model_id,
+        embedding_dimension=384,
+    )
+    client_with_empty_registry.tool_runtime.rag_tool.insert(
         documents=sample_documents,
         chunk_size_in_tokens=512,
         vector_db_id=vector_db_id,
@@ -73,7 +73,7 @@ def test_vector_db_insert_inline_and_query(llama_stack_client, single_entry_vect
     # Query with a direct match
     query1 = "programming language"
-    response1 = llama_stack_client.vector_io.query(
+    response1 = client_with_empty_registry.vector_io.query(
         vector_db_id=vector_db_id,
         query=query1,
     )
@@ -82,7 +82,7 @@ def test_vector_db_insert_inline_and_query(llama_stack_client, single_entry_vect
     # Query with semantic similarity
     query2 = "AI and brain-inspired computing"
-    response2 = llama_stack_client.vector_io.query(
+    response2 = client_with_empty_registry.vector_io.query(
         vector_db_id=vector_db_id,
         query=query2,
     )
@@ -91,7 +91,7 @@ def test_vector_db_insert_inline_and_query(llama_stack_client, single_entry_vect
     # Query with limit on number of results (max_chunks=2)
     query3 = "computer"
-    response3 = llama_stack_client.vector_io.query(
+    response3 = client_with_empty_registry.vector_io.query(
         vector_db_id=vector_db_id,
         query=query3,
         params={"max_chunks": 2},
@@ -101,7 +101,7 @@ def test_vector_db_insert_inline_and_query(llama_stack_client, single_entry_vect
     # Query with threshold on similarity score
     query4 = "computer"
-    response4 = llama_stack_client.vector_io.query(
+    response4 = client_with_empty_registry.vector_io.query(
         vector_db_id=vector_db_id,
         query=query4,
         params={"score_threshold": 0.01},
@@ -110,20 +110,20 @@ def test_vector_db_insert_inline_and_query(llama_stack_client, single_entry_vect
     assert all(score >= 0.01 for score in response4.scores)
-def test_vector_db_insert_from_url_and_query(llama_stack_client, empty_vector_db_registry):
-    providers = [p for p in llama_stack_client.providers.list() if p.api == "vector_io"]
+def test_vector_db_insert_from_url_and_query(client_with_empty_registry, sample_documents, embedding_model_id):
+    providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
     assert len(providers) > 0
     vector_db_id = "test_vector_db"
-    llama_stack_client.vector_dbs.register(
+    client_with_empty_registry.vector_dbs.register(
         vector_db_id=vector_db_id,
-        embedding_model="all-MiniLM-L6-v2",
+        embedding_model=embedding_model_id,
         embedding_dimension=384,
     )
     # list to check memory bank is successfully registered
-    available_vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
+    available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
     assert vector_db_id in available_vector_dbs
     # URLs of documents to insert
@@ -144,14 +144,14 @@ def test_vector_db_insert_from_url_and_query(llama_stack_client, empty_vector_db
         for i, url in enumerate(urls)
     ]
-    llama_stack_client.tool_runtime.rag_tool.insert(
+    client_with_empty_registry.tool_runtime.rag_tool.insert(
         documents=documents,
         vector_db_id=vector_db_id,
         chunk_size_in_tokens=512,
     )
     # Query for the name of method
-    response1 = llama_stack_client.vector_io.query(
+    response1 = client_with_empty_registry.vector_io.query(
         vector_db_id=vector_db_id,
         query="What's the name of the fine-tunning method used?",
     )
@@ -159,7 +159,7 @@ def test_vector_db_insert_from_url_and_query(llama_stack_client, empty_vector_db
     assert any("lora" in chunk.content.lower() for chunk in response1.chunks)
     # Query for the name of model
-    response2 = llama_stack_client.vector_io.query(
+    response2 = client_with_empty_registry.vector_io.query(
         vector_db_id=vector_db_id,
         query="Which Llama model is mentioned?",
     )