mirror of https://github.com/meta-llama/llama-stack.git

commit 50852cadf3 (parent 18d9937500)

    add tool tests

6 changed files with 309 additions and 0 deletions
.gitignore (vendored) | 1 +

@@ -19,3 +19,4 @@ Package.resolved
 _build
 docs/src
 pyrightconfig.json
+.aider*

@@ -19,6 +19,7 @@ from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
 from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
 from llama_stack.providers.tests.resolver import construct_stack_for_test
 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig

 from ..conftest import ProviderFixture, remote_stack_fixture
 from ..env import get_env_or_fail
llama_stack/providers/tests/tools/__init__.py (new file, 5 lines):

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
llama_stack/providers/tests/tools/conftest.py (new file, 65 lines):

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest

from ..conftest import get_provider_fixture_overrides
from ..inference.fixtures import INFERENCE_FIXTURES
from ..memory.fixtures import MEMORY_FIXTURES
from ..safety.fixtures import SAFETY_FIXTURES
from .fixtures import TOOL_RUNTIME_FIXTURES

DEFAULT_PROVIDER_COMBINATIONS = [
    pytest.param(
        {
            "inference": "together",
            "safety": "llama_guard",
            "memory": "faiss",
            "tool_runtime": "memory_and_search",
        },
        id="together",
        marks=pytest.mark.together,
    ),
]


def pytest_configure(config):
    for mark in ["together"]:
        config.addinivalue_line(
            "markers",
            f"{mark}: marks tests as {mark} specific",
        )


def pytest_addoption(parser):
    parser.addoption(
        "--inference-model",
        action="store",
        default="meta-llama/Llama-3.2-3B-Instruct",
        help="Specify the inference model to use for testing",
    )
    parser.addoption(
        "--safety-shield",
        action="store",
        default="meta-llama/Llama-Guard-3-1B",
        help="Specify the safety shield to use for testing",
    )


def pytest_generate_tests(metafunc):
    if "tools_stack" in metafunc.fixturenames:
        available_fixtures = {
            "inference": INFERENCE_FIXTURES,
            "safety": SAFETY_FIXTURES,
            "memory": MEMORY_FIXTURES,
            "tool_runtime": TOOL_RUNTIME_FIXTURES,
        }
        combinations = (
            get_provider_fixture_overrides(metafunc.config, available_fixtures)
            or DEFAULT_PROVIDER_COMBINATIONS
        )
        print(combinations)
        metafunc.parametrize("tools_stack", combinations, indirect=True)
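
Note: the conftest above drives fixture selection for the new suite. pytest_generate_tests parametrizes tools_stack with either provider overrides from the shared test conftest or the default "together" combination, and pytest_addoption exposes --inference-model and --safety-shield. A minimal sketch of one way to run the suite with those defaults follows; the entry point is illustrative only, and it assumes pytest is installed and TAVILY_SEARCH_API_KEY is exported for the search fixture defined next.

# Sketch only: run the new tool tests with the default "together" combination.
import pytest

pytest.main(
    [
        "llama_stack/providers/tests/tools/test_tools.py",
        "-m", "together",
        "--inference-model=meta-llama/Llama-3.2-3B-Instruct",
        "--safety-shield=meta-llama/Llama-Guard-3-1B",
        "-v",
    ]
)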
llama_stack/providers/tests/tools/fixtures.py (new file, 138 lines):

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os

import pytest
import pytest_asyncio

from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.apis.tools import (
    BuiltInToolDef,
    CustomToolDef,
    ToolGroupInput,
    ToolParameter,
    UserDefinedToolGroupDef,
)
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import construct_stack_for_test

from ..conftest import ProviderFixture


@pytest.fixture(scope="session")
def tool_runtime_memory_and_search() -> ProviderFixture:
    return ProviderFixture(
        providers=[
            Provider(
                provider_id="memory-runtime",
                provider_type="inline::memory-runtime",
                config={},
            ),
            Provider(
                provider_id="tavily-search",
                provider_type="remote::tavily-search",
                config={
                    "api_key": os.environ["TAVILY_SEARCH_API_KEY"],
                },
            ),
        ],
    )


TOOL_RUNTIME_FIXTURES = ["memory_and_search"]


@pytest_asyncio.fixture(scope="session")
async def tools_stack(request, inference_model, safety_shield):
    fixture_dict = request.param

    providers = {}
    provider_data = {}
    for key in ["inference", "memory", "tools", "tool_runtime"]:
        fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
        providers[key] = fixture.providers
        if key == "inference":
            providers[key].append(
                Provider(
                    provider_id="tools_memory_provider",
                    provider_type="inline::sentence-transformers",
                    config={},
                )
            )
        if fixture.provider_data:
            provider_data.update(fixture.provider_data)
    inference_models = (
        inference_model if isinstance(inference_model, list) else [inference_model]
    )
    models = [
        ModelInput(
            model_id=model,
            model_type=ModelType.llm,
            provider_id=providers["inference"][0].provider_id,
        )
        for model in inference_models
    ]
    models.append(
        ModelInput(
            model_id="all-MiniLM-L6-v2",
            model_type=ModelType.embedding,
            provider_id="tools_memory_provider",
            metadata={"embedding_dimension": 384},
        )
    )

    tool_groups = [
        ToolGroupInput(
            tool_group_id="tavily_search_group",
            tool_group=UserDefinedToolGroupDef(
                tools=[
                    BuiltInToolDef(
                        name="brave_search",
                        description="Search the web using Brave Search",
                        metadata={},
                    ),
                ],
            ),
            provider_id="tavily-search",
        ),
        ToolGroupInput(
            tool_group_id="memory_group",
            tool_group=UserDefinedToolGroupDef(
                tools=[
                    CustomToolDef(
                        name="memory",
                        description="Query the memory bank",
                        parameters=[
                            ToolParameter(
                                name="query",
                                description="The query to search for in memory",
                                parameter_type="string",
                                required=True,
                            ),
                            ToolParameter(
                                name="memory_bank_id",
                                description="The ID of the memory bank to search",
                                parameter_type="string",
                                required=True,
                            ),
                        ],
                        metadata={},
                    )
                ],
            ),
            provider_id="memory-runtime",
        ),
    ]

    test_stack = await construct_stack_for_test(
        [Api.tools, Api.inference, Api.memory],
        providers,
        provider_data,
        models=models,
        tool_groups=tool_groups,
    )
    return test_stack
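
Note: tool_runtime_memory_and_search reads TAVILY_SEARCH_API_KEY with os.environ[...], so constructing the stack raises KeyError when the key is absent; only test_brave_search_tool skips explicitly. A small, hypothetical guard (not part of this commit) that a fixture could call to skip the whole combination instead:

# Hypothetical helper, shown only as a sketch; the committed fixture indexes
# os.environ directly and will raise KeyError if the key is absent.
import os

import pytest


def require_tavily_key() -> str:
    api_key = os.environ.get("TAVILY_SEARCH_API_KEY")
    if not api_key:
        pytest.skip("TAVILY_SEARCH_API_KEY not set; skipping tool_runtime tests")
    return api_key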
llama_stack/providers/tests/tools/test_tools.py (new file, 99 lines):

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os

import pytest

from llama_stack.apis.memory import MemoryBankDocument
from llama_stack.apis.memory_banks import VectorMemoryBankParams
from llama_stack.apis.tools import ToolInvocationResult
from llama_stack.providers.datatypes import Api


@pytest.fixture
def sample_search_query():
    return "What are the latest developments in quantum computing?"


@pytest.fixture
def sample_documents():
    urls = [
        "memory_optimizations.rst",
        "chat.rst",
        "llama3.rst",
        "datasets.rst",
        "qat_finetune.rst",
        "lora_finetune.rst",
    ]
    return [
        MemoryBankDocument(
            document_id=f"num-{i}",
            content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
            mime_type="text/plain",
            metadata={},
        )
        for i, url in enumerate(urls)
    ]


class TestTools:
    @pytest.mark.asyncio
    async def test_brave_search_tool(self, tools_stack, sample_search_query):
        """Test the Brave search tool functionality."""
        if "TAVILY_SEARCH_API_KEY" not in os.environ:
            pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")

        tools_impl = tools_stack.impls[Api.tool_runtime]

        # Execute the tool
        response = await tools_impl.invoke_tool(
            tool_name="brave_search", tool_args={"query": sample_search_query}
        )

        # Verify the response
        assert isinstance(response, ToolInvocationResult)
        assert response.content is not None
        assert len(response.content) > 0
        assert isinstance(response.content, str)

    @pytest.mark.asyncio
    async def test_memory_tool(self, tools_stack, sample_documents):
        """Test the memory tool functionality."""
        memory_banks_impl = tools_stack.impls[Api.memory_banks]
        memory_impl = tools_stack.impls[Api.memory]
        tools_impl = tools_stack.impls[Api.tools]

        # Register memory bank
        await memory_banks_impl.register_memory_bank(
            memory_bank_id="test_memory_bank",
            params=VectorMemoryBankParams(
                embedding_model="all-MiniLM-L6-v2",
                chunk_size_in_tokens=512,
                overlap_size_in_tokens=64,
            ),
            provider_id="faiss",
        )

        # Insert documents into memory
        memory_impl.insert_documents(
            bank_id="test_memory_bank",
            documents=sample_documents,
        )

        # Execute the memory tool
        response = await tools_impl.invoke_tool(
            tool_name="memory",
            tool_args={
                "query": "What are the main topics covered in the documentation?",
            },
        )

        # Verify the response
        assert isinstance(response, ToolInvocationResult)
        assert response.content is not None
        assert len(response.content) > 0
        assert isinstance(response.content, str)
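
Note: the memory tool registered in fixtures.py declares two required parameters, query and memory_bank_id, while test_memory_tool only supplies query. For reference, a variant that passes both declared parameters might look like the sketch below; it is illustrative only, not part of the commit, and assumes the same bank registration and document insertion steps as test_memory_tool.

    @pytest.mark.asyncio
    async def test_memory_tool_with_bank_id(self, tools_stack, sample_documents):
        """Hypothetical variant (not in this commit): pass both declared parameters."""
        tools_impl = tools_stack.impls[Api.tools]
        # Assumes the same bank registration / document insertion as test_memory_tool.
        response = await tools_impl.invoke_tool(
            tool_name="memory",
            tool_args={
                "query": "What are the main topics covered in the documentation?",
                "memory_bank_id": "test_memory_bank",
            },
        )
        assert isinstance(response, ToolInvocationResult)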