passing tool tests

This commit is contained in:
Dinesh Yeduguru 2024-12-30 11:37:28 -08:00
parent 50852cadf3
commit b7ae86ae03
2 changed files with 29 additions and 23 deletions

View file

@ -17,6 +17,7 @@ from llama_stack.apis.tools import (
ToolParameter, ToolParameter,
UserDefinedToolGroupDef, UserDefinedToolGroupDef,
) )
from llama_stack.apis.tools.tools import BuiltinTool
from llama_stack.distribution.datatypes import Api, Provider from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import construct_stack_for_test from llama_stack.providers.tests.resolver import construct_stack_for_test
@ -47,12 +48,12 @@ TOOL_RUNTIME_FIXTURES = ["memory_and_search"]
@pytest_asyncio.fixture(scope="session") @pytest_asyncio.fixture(scope="session")
async def tools_stack(request, inference_model, safety_shield): async def tools_stack(request, inference_model):
fixture_dict = request.param fixture_dict = request.param
providers = {} providers = {}
provider_data = {} provider_data = {}
for key in ["inference", "memory", "tools", "tool_runtime"]: for key in ["inference", "memory", "tool_runtime"]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}") fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers providers[key] = fixture.providers
if key == "inference": if key == "inference":
@ -91,8 +92,7 @@ async def tools_stack(request, inference_model, safety_shield):
tool_group=UserDefinedToolGroupDef( tool_group=UserDefinedToolGroupDef(
tools=[ tools=[
BuiltInToolDef( BuiltInToolDef(
name="brave_search", built_in_type=BuiltinTool.brave_search,
description="Search the web using Brave Search",
metadata={}, metadata={},
), ),
], ],
@ -108,19 +108,19 @@ async def tools_stack(request, inference_model, safety_shield):
description="Query the memory bank", description="Query the memory bank",
parameters=[ parameters=[
ToolParameter( ToolParameter(
name="query", name="input_messages",
description="The query to search for in memory", description="The input messages to search for in memory",
parameter_type="string", parameter_type="list",
required=True,
),
ToolParameter(
name="memory_bank_id",
description="The ID of the memory bank to search",
parameter_type="string",
required=True, required=True,
), ),
], ],
metadata={}, metadata={
"config": {
"memory_bank_configs": [
{"bank_id": "test_bank", "type": "vector"}
]
}
},
) )
], ],
), ),
@ -129,7 +129,7 @@ async def tools_stack(request, inference_model, safety_shield):
] ]
test_stack = await construct_stack_for_test( test_stack = await construct_stack_for_test(
[Api.tools, Api.inference, Api.memory], [Api.tool_groups, Api.inference, Api.memory, Api.tool_runtime],
providers, providers,
provider_data, provider_data,
models=models, models=models,

View file

@ -8,11 +8,14 @@ import os
import pytest import pytest
from llama_stack.apis.inference import UserMessage
from llama_stack.apis.memory import MemoryBankDocument from llama_stack.apis.memory import MemoryBankDocument
from llama_stack.apis.memory_banks import VectorMemoryBankParams from llama_stack.apis.memory_banks import VectorMemoryBankParams
from llama_stack.apis.tools import ToolInvocationResult from llama_stack.apis.tools import ToolInvocationResult
from llama_stack.providers.datatypes import Api from llama_stack.providers.datatypes import Api
from .fixtures import tool_runtime_memory_and_search # noqa: F401
@pytest.fixture @pytest.fixture
def sample_search_query(): def sample_search_query():
@ -51,7 +54,7 @@ class TestTools:
# Execute the tool # Execute the tool
response = await tools_impl.invoke_tool( response = await tools_impl.invoke_tool(
tool_name="brave_search", tool_args={"query": sample_search_query} tool_name="brave_search", args={"query": sample_search_query}
) )
# Verify the response # Verify the response
@ -65,11 +68,11 @@ class TestTools:
"""Test the memory tool functionality.""" """Test the memory tool functionality."""
memory_banks_impl = tools_stack.impls[Api.memory_banks] memory_banks_impl = tools_stack.impls[Api.memory_banks]
memory_impl = tools_stack.impls[Api.memory] memory_impl = tools_stack.impls[Api.memory]
tools_impl = tools_stack.impls[Api.tools] tools_impl = tools_stack.impls[Api.tool_runtime]
# Register memory bank # Register memory bank
await memory_banks_impl.register_memory_bank( await memory_banks_impl.register_memory_bank(
memory_bank_id="test_memory_bank", memory_bank_id="test_bank",
params=VectorMemoryBankParams( params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2", embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512, chunk_size_in_tokens=512,
@ -79,16 +82,20 @@ class TestTools:
) )
# Insert documents into memory # Insert documents into memory
memory_impl.insert_documents( await memory_impl.insert_documents(
bank_id="test_memory_bank", bank_id="test_bank",
documents=sample_documents, documents=sample_documents,
) )
# Execute the memory tool # Execute the memory tool
response = await tools_impl.invoke_tool( response = await tools_impl.invoke_tool(
tool_name="memory", tool_name="memory",
tool_args={ args={
"query": "What are the main topics covered in the documentation?", "input_messages": [
UserMessage(
content="What are the main topics covered in the documentation?",
)
],
}, },
) )
@ -96,4 +103,3 @@ class TestTools:
assert isinstance(response, ToolInvocationResult) assert isinstance(response, ToolInvocationResult)
assert response.content is not None assert response.content is not None
assert len(response.content) > 0 assert len(response.content) > 0
assert isinstance(response.content, str)