passing tool tests

This commit is contained in:
Dinesh Yeduguru 2024-12-30 11:37:28 -08:00
parent 50852cadf3
commit b7ae86ae03
2 changed files with 29 additions and 23 deletions

View file

@ -17,6 +17,7 @@ from llama_stack.apis.tools import (
ToolParameter,
UserDefinedToolGroupDef,
)
from llama_stack.apis.tools.tools import BuiltinTool
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import construct_stack_for_test
@ -47,12 +48,12 @@ TOOL_RUNTIME_FIXTURES = ["memory_and_search"]
@pytest_asyncio.fixture(scope="session")
async def tools_stack(request, inference_model, safety_shield):
async def tools_stack(request, inference_model):
fixture_dict = request.param
providers = {}
provider_data = {}
for key in ["inference", "memory", "tools", "tool_runtime"]:
for key in ["inference", "memory", "tool_runtime"]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers
if key == "inference":
@ -91,8 +92,7 @@ async def tools_stack(request, inference_model, safety_shield):
tool_group=UserDefinedToolGroupDef(
tools=[
BuiltInToolDef(
name="brave_search",
description="Search the web using Brave Search",
built_in_type=BuiltinTool.brave_search,
metadata={},
),
],
@ -108,19 +108,19 @@ async def tools_stack(request, inference_model, safety_shield):
description="Query the memory bank",
parameters=[
ToolParameter(
name="query",
description="The query to search for in memory",
parameter_type="string",
required=True,
),
ToolParameter(
name="memory_bank_id",
description="The ID of the memory bank to search",
parameter_type="string",
name="input_messages",
description="The input messages to search for in memory",
parameter_type="list",
required=True,
),
],
metadata={},
metadata={
"config": {
"memory_bank_configs": [
{"bank_id": "test_bank", "type": "vector"}
]
}
},
)
],
),
@ -129,7 +129,7 @@ async def tools_stack(request, inference_model, safety_shield):
]
test_stack = await construct_stack_for_test(
[Api.tools, Api.inference, Api.memory],
[Api.tool_groups, Api.inference, Api.memory, Api.tool_runtime],
providers,
provider_data,
models=models,

View file

@ -8,11 +8,14 @@ import os
import pytest
from llama_stack.apis.inference import UserMessage
from llama_stack.apis.memory import MemoryBankDocument
from llama_stack.apis.memory_banks import VectorMemoryBankParams
from llama_stack.apis.tools import ToolInvocationResult
from llama_stack.providers.datatypes import Api
from .fixtures import tool_runtime_memory_and_search # noqa: F401
@pytest.fixture
def sample_search_query():
@ -51,7 +54,7 @@ class TestTools:
# Execute the tool
response = await tools_impl.invoke_tool(
tool_name="brave_search", tool_args={"query": sample_search_query}
tool_name="brave_search", args={"query": sample_search_query}
)
# Verify the response
@ -65,11 +68,11 @@ class TestTools:
"""Test the memory tool functionality."""
memory_banks_impl = tools_stack.impls[Api.memory_banks]
memory_impl = tools_stack.impls[Api.memory]
tools_impl = tools_stack.impls[Api.tools]
tools_impl = tools_stack.impls[Api.tool_runtime]
# Register memory bank
await memory_banks_impl.register_memory_bank(
memory_bank_id="test_memory_bank",
memory_bank_id="test_bank",
params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
@ -79,16 +82,20 @@ class TestTools:
)
# Insert documents into memory
memory_impl.insert_documents(
bank_id="test_memory_bank",
await memory_impl.insert_documents(
bank_id="test_bank",
documents=sample_documents,
)
# Execute the memory tool
response = await tools_impl.invoke_tool(
tool_name="memory",
tool_args={
"query": "What are the main topics covered in the documentation?",
args={
"input_messages": [
UserMessage(
content="What are the main topics covered in the documentation?",
)
],
},
)
@ -96,4 +103,3 @@ class TestTools:
assert isinstance(response, ToolInvocationResult)
assert response.content is not None
assert len(response.content) > 0
assert isinstance(response.content, str)