add tool tests

This commit is contained in:
Dinesh Yeduguru 2024-12-30 10:50:59 -08:00
parent 18d9937500
commit 50852cadf3
6 changed files with 309 additions and 0 deletions

1
.gitignore vendored
View file

@ -19,3 +19,4 @@ Package.resolved
_build
docs/src
pyrightconfig.json
.aider*

View file

@ -19,6 +19,7 @@ from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
from llama_stack.providers.tests.resolver import construct_stack_for_test
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from ..conftest import ProviderFixture, remote_stack_fixture
from ..env import get_env_or_fail

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from ..conftest import get_provider_fixture_overrides
from ..inference.fixtures import INFERENCE_FIXTURES
from ..memory.fixtures import MEMORY_FIXTURES
from ..safety.fixtures import SAFETY_FIXTURES
from .fixtures import TOOL_RUNTIME_FIXTURES
# Provider combinations used to parametrize ``tools_stack`` when
# ``get_provider_fixture_overrides`` yields nothing. Each dict maps an API
# name to the name of the provider fixture variant to use for that API.
DEFAULT_PROVIDER_COMBINATIONS = [
    pytest.param(
        {
            "inference": "together",
            "safety": "llama_guard",
            "memory": "faiss",
            "tool_runtime": "memory_and_search",
        },
        marks=pytest.mark.together,
        id="together",
    ),
]
def pytest_configure(config):
    """Register the custom markers used by the provider combinations.

    Adds one ``markers`` ini line per marker name so the marker is declared
    to pytest.
    """
    register = config.addinivalue_line
    for mark in ("together",):
        register("markers", f"{mark}: marks tests as {mark} specific")
def pytest_addoption(parser):
    """Add the ``--inference-model`` and ``--safety-shield`` options.

    Both options store a single string value and fall back to the defaults
    listed below when not given on the command line.
    """
    # (flag, default value, help text) for each option, in registration order.
    option_specs = (
        (
            "--inference-model",
            "meta-llama/Llama-3.2-3B-Instruct",
            "Specify the inference model to use for testing",
        ),
        (
            "--safety-shield",
            "meta-llama/Llama-Guard-3-1B",
            "Specify the safety shield to use for testing",
        ),
    )
    for flag, default_value, help_text in option_specs:
        parser.addoption(
            flag,
            action="store",
            default=default_value,
            help=help_text,
        )
def pytest_generate_tests(metafunc):
    """Parametrize ``tools_stack`` with the provider combinations to test.

    Tests that do not request the ``tools_stack`` fixture are left untouched.
    For the others, the combinations returned by
    ``get_provider_fixture_overrides`` are used when that call returns a
    truthy value; otherwise ``DEFAULT_PROVIDER_COMBINATIONS`` is used.

    The parametrization is ``indirect`` so that each combination is passed
    to the ``tools_stack`` fixture (as ``request.param``) rather than
    directly to the test function.
    """
    if "tools_stack" not in metafunc.fixturenames:
        return

    available_fixtures = {
        "inference": INFERENCE_FIXTURES,
        "safety": SAFETY_FIXTURES,
        "memory": MEMORY_FIXTURES,
        "tool_runtime": TOOL_RUNTIME_FIXTURES,
    }
    combinations = (
        get_provider_fixture_overrides(metafunc.config, available_fixtures)
        or DEFAULT_PROVIDER_COMBINATIONS
    )
    # No debug printing here: this hook runs during collection for every
    # test that requests ``tools_stack``.
    metafunc.parametrize("tools_stack", combinations, indirect=True)

View file

@ -0,0 +1,138 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import pytest
import pytest_asyncio
from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.apis.tools import (
BuiltInToolDef,
CustomToolDef,
ToolGroupInput,
ToolParameter,
UserDefinedToolGroupDef,
)
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture
@pytest.fixture(scope="session")
def tool_runtime_memory_and_search() -> ProviderFixture:
    """Return the tool-runtime providers: memory runtime plus Tavily search.

    The Tavily API key is read from the ``TAVILY_SEARCH_API_KEY``
    environment variable; a ``KeyError`` is raised if it is not set.
    """
    memory_runtime = Provider(
        provider_id="memory-runtime",
        provider_type="inline::memory-runtime",
        config={},
    )
    tavily_search = Provider(
        provider_id="tavily-search",
        provider_type="remote::tavily-search",
        config={"api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
    )
    return ProviderFixture(providers=[memory_runtime, tavily_search])
# Names of the available tool-runtime fixture variants. Each name is the
# suffix of a ``tool_runtime_<name>`` fixture function in this module.
TOOL_RUNTIME_FIXTURES = ["memory_and_search"]
@pytest_asyncio.fixture(scope="session")
async def tools_stack(request, inference_model, safety_shield):
    """Construct the test stack used by the tool tests.

    ``request.param`` is a dict mapping an API name (``"inference"``,
    ``"memory"``, ``"tool_runtime"``, ...) to the name of the provider
    fixture variant to use; the fixture ``f"{api}_{variant}"`` is resolved
    for each API handled here.

    Args:
        request: The pytest fixture request carrying the provider combination.
        inference_model: A model id, or a list of model ids, to register as
            LLM models on the first inference provider.
        safety_shield: Requested as a fixture but not referenced in this body.

    Returns:
        The object returned by ``construct_stack_for_test``.
    """
    fixture_dict = request.param

    providers = {}
    provider_data = {}
    # Only resolve APIs that have an entry in the provider combination.
    # There is no "tools" entry, so looking it up would raise KeyError.
    for key in ["inference", "memory", "tool_runtime"]:
        fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
        # Copy the list: the provider fixtures are shared objects and the
        # append below must not leak an extra provider into them.
        providers[key] = list(fixture.providers)
        if key == "inference":
            # Extra inference provider that serves the embedding model
            # registered further down.
            providers[key].append(
                Provider(
                    provider_id="tools_memory_provider",
                    provider_type="inline::sentence-transformers",
                    config={},
                )
            )
        if fixture.provider_data:
            provider_data.update(fixture.provider_data)

    inference_models = (
        inference_model if isinstance(inference_model, list) else [inference_model]
    )
    models = [
        ModelInput(
            model_id=model,
            model_type=ModelType.llm,
            provider_id=providers["inference"][0].provider_id,
        )
        for model in inference_models
    ]
    # Embedding model for the memory bank, served by the
    # sentence-transformers provider added above.
    models.append(
        ModelInput(
            model_id="all-MiniLM-L6-v2",
            model_type=ModelType.embedding,
            provider_id="tools_memory_provider",
            metadata={"embedding_dimension": 384},
        )
    )

    tool_groups = [
        # The tool is registered under the name "brave_search" but is
        # routed to the "tavily-search" provider.
        ToolGroupInput(
            tool_group_id="tavily_search_group",
            tool_group=UserDefinedToolGroupDef(
                tools=[
                    BuiltInToolDef(
                        name="brave_search",
                        description="Search the web using Brave Search",
                        metadata={},
                    ),
                ],
            ),
            provider_id="tavily-search",
        ),
        ToolGroupInput(
            tool_group_id="memory_group",
            tool_group=UserDefinedToolGroupDef(
                tools=[
                    CustomToolDef(
                        name="memory",
                        description="Query the memory bank",
                        parameters=[
                            ToolParameter(
                                name="query",
                                description="The query to search for in memory",
                                parameter_type="string",
                                required=True,
                            ),
                            ToolParameter(
                                name="memory_bank_id",
                                description="The ID of the memory bank to search",
                                parameter_type="string",
                                required=True,
                            ),
                        ],
                        metadata={},
                    )
                ],
            ),
            provider_id="memory-runtime",
        ),
    ]

    test_stack = await construct_stack_for_test(
        [Api.tools, Api.inference, Api.memory],
        providers,
        provider_data,
        models=models,
        tool_groups=tool_groups,
    )
    return test_stack

View file

@ -0,0 +1,99 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import pytest
from llama_stack.apis.memory import MemoryBankDocument
from llama_stack.apis.memory_banks import VectorMemoryBankParams
from llama_stack.apis.tools import ToolInvocationResult
from llama_stack.providers.datatypes import Api
@pytest.fixture
def sample_search_query():
    """Return the query string passed to the web-search tool test."""
    query = "What are the latest developments in quantum computing?"
    return query
@pytest.fixture
def sample_documents():
    """Return one ``MemoryBankDocument`` per torchtune tutorial file.

    Each document's ``content`` is the raw GitHub URL of the tutorial and
    its ``document_id`` is ``num-<index>`` in list order.
    """
    tutorial_files = (
        "memory_optimizations.rst",
        "chat.rst",
        "llama3.rst",
        "datasets.rst",
        "qat_finetune.rst",
        "lora_finetune.rst",
    )
    documents = []
    for i, url in enumerate(tutorial_files):
        documents.append(
            MemoryBankDocument(
                document_id=f"num-{i}",
                content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
                mime_type="text/plain",
                metadata={},
            )
        )
    return documents
class TestTools:
    """Tests that invoke the registered tools through the ``tools_stack``."""

    # Skip via a mark rather than inside the test body: an in-body
    # ``pytest.skip`` only runs after ``tools_stack`` has been set up, and
    # that setup already reads TAVILY_SEARCH_API_KEY from the environment.
    @pytest.mark.skipif(
        "TAVILY_SEARCH_API_KEY" not in os.environ,
        reason="TAVILY_SEARCH_API_KEY not set, skipping test",
    )
    @pytest.mark.asyncio
    async def test_brave_search_tool(self, tools_stack, sample_search_query):
        """Invoke the ``brave_search`` tool and check it returns text."""
        tools_impl = tools_stack.impls[Api.tool_runtime]

        # Execute the tool
        response = await tools_impl.invoke_tool(
            tool_name="brave_search", tool_args={"query": sample_search_query}
        )

        # Verify the response
        assert isinstance(response, ToolInvocationResult)
        assert response.content is not None
        assert len(response.content) > 0
        assert isinstance(response.content, str)

    @pytest.mark.asyncio
    async def test_memory_tool(self, tools_stack, sample_documents):
        """Populate a memory bank, then query it through the ``memory`` tool."""
        memory_bank_id = "test_memory_bank"

        memory_banks_impl = tools_stack.impls[Api.memory_banks]
        memory_impl = tools_stack.impls[Api.memory]
        # NOTE(review): the search test invokes tools via Api.tool_runtime
        # while this one uses Api.tools -- confirm which impl is intended.
        tools_impl = tools_stack.impls[Api.tools]

        # Register memory bank
        await memory_banks_impl.register_memory_bank(
            memory_bank_id=memory_bank_id,
            params=VectorMemoryBankParams(
                embedding_model="all-MiniLM-L6-v2",
                chunk_size_in_tokens=512,
                overlap_size_in_tokens=64,
            ),
            provider_id="faiss",
        )

        # Insert documents into memory. The call must be awaited like the
        # other impl calls; otherwise the insert never actually runs.
        await memory_impl.insert_documents(
            bank_id=memory_bank_id,
            documents=sample_documents,
        )

        # Execute the memory tool. The tool definition declares both
        # ``query`` and ``memory_bank_id`` as required parameters.
        response = await tools_impl.invoke_tool(
            tool_name="memory",
            tool_args={
                "query": "What are the main topics covered in the documentation?",
                "memory_bank_id": memory_bank_id,
            },
        )

        # Verify the response
        assert isinstance(response, ToolInvocationResult)
        assert response.content is not None
        assert len(response.content) > 0
        assert isinstance(response.content, str)