forked from phoenix-oss/llama-stack-mirror
# What does this PR do? PR #639 introduced the notion of Tools API and ability to invoke tools through API just as any resource. This PR changes the Agents to start using the Tools API to invoke tools. Major changes include: 1) Ability to specify tool groups with AgentConfig 2) Agent gets the corresponding tool definitions for the specified tools and pass along to the model 3) Attachements are now named as Documents and their behavior is mostly unchanged from user perspective 4) You can specify args that can be injected to a tool call through Agent config. This is especially useful in case of memory tool, where you want the tool to operate on a specific memory bank. 5) You can also register tool groups with args, which lets the agent inject these as well into the tool call. 6) All tests have been migrated to use new tools API and fixtures including client SDK tests 7) Telemetry just works with tools API because of our trace protocol decorator ## Test Plan ``` pytest -s -v -k fireworks llama_stack/providers/tests/agents/test_agents.py \ --safety-shield=meta-llama/Llama-Guard-3-8B \ --inference-model=meta-llama/Llama-3.1-8B-Instruct pytest -s -v -k together llama_stack/providers/tests/tools/test_tools.py \ --safety-shield=meta-llama/Llama-Guard-3-8B \ --inference-model=meta-llama/Llama-3.1-8B-Instruct LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml" pytest -v tests/client-sdk/agents/test_agents.py ``` run.yaml: https://gist.github.com/dineshyv/0365845ad325e1c2cab755788ccc5994 Notebook: https://colab.research.google.com/drive/1ck7hXQxRl6UvT-ijNRZ-gMZxH1G3cN2d?usp=sharing
127 lines
4 KiB
Python
127 lines
4 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import os
|
|
|
|
import pytest
|
|
|
|
from llama_stack.apis.inference import UserMessage
|
|
from llama_stack.apis.memory import MemoryBankDocument
|
|
from llama_stack.apis.memory_banks import VectorMemoryBankParams
|
|
from llama_stack.apis.tools import ToolInvocationResult
|
|
from llama_stack.providers.datatypes import Api
|
|
|
|
|
|
@pytest.fixture
|
|
def sample_search_query():
|
|
return "What are the latest developments in quantum computing?"
|
|
|
|
|
|
@pytest.fixture
|
|
def sample_wolfram_alpha_query():
|
|
return "What is the square root of 16?"
|
|
|
|
|
|
@pytest.fixture
|
|
def sample_documents():
|
|
urls = [
|
|
"memory_optimizations.rst",
|
|
"chat.rst",
|
|
"llama3.rst",
|
|
"datasets.rst",
|
|
"qat_finetune.rst",
|
|
"lora_finetune.rst",
|
|
]
|
|
return [
|
|
MemoryBankDocument(
|
|
document_id=f"num-{i}",
|
|
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
|
|
mime_type="text/plain",
|
|
metadata={},
|
|
)
|
|
for i, url in enumerate(urls)
|
|
]
|
|
|
|
|
|
class TestTools:
|
|
@pytest.mark.asyncio
|
|
async def test_web_search_tool(self, tools_stack, sample_search_query):
|
|
"""Test the web search tool functionality."""
|
|
if "TAVILY_SEARCH_API_KEY" not in os.environ:
|
|
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
|
|
|
|
tools_impl = tools_stack.impls[Api.tool_runtime]
|
|
|
|
# Execute the tool
|
|
response = await tools_impl.invoke_tool(
|
|
tool_name="web_search", args={"query": sample_search_query}
|
|
)
|
|
|
|
# Verify the response
|
|
assert isinstance(response, ToolInvocationResult)
|
|
assert response.content is not None
|
|
assert len(response.content) > 0
|
|
assert isinstance(response.content, str)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_wolfram_alpha_tool(self, tools_stack, sample_wolfram_alpha_query):
|
|
"""Test the wolfram alpha tool functionality."""
|
|
if "WOLFRAM_ALPHA_API_KEY" not in os.environ:
|
|
pytest.skip("WOLFRAM_ALPHA_API_KEY not set, skipping test")
|
|
|
|
tools_impl = tools_stack.impls[Api.tool_runtime]
|
|
|
|
response = await tools_impl.invoke_tool(
|
|
tool_name="wolfram_alpha", args={"query": sample_wolfram_alpha_query}
|
|
)
|
|
|
|
# Verify the response
|
|
assert isinstance(response, ToolInvocationResult)
|
|
assert response.content is not None
|
|
assert len(response.content) > 0
|
|
assert isinstance(response.content, str)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_memory_tool(self, tools_stack, sample_documents):
|
|
"""Test the memory tool functionality."""
|
|
memory_banks_impl = tools_stack.impls[Api.memory_banks]
|
|
memory_impl = tools_stack.impls[Api.memory]
|
|
tools_impl = tools_stack.impls[Api.tool_runtime]
|
|
|
|
# Register memory bank
|
|
await memory_banks_impl.register_memory_bank(
|
|
memory_bank_id="test_bank",
|
|
params=VectorMemoryBankParams(
|
|
embedding_model="all-MiniLM-L6-v2",
|
|
chunk_size_in_tokens=512,
|
|
overlap_size_in_tokens=64,
|
|
),
|
|
provider_id="faiss",
|
|
)
|
|
|
|
# Insert documents into memory
|
|
await memory_impl.insert_documents(
|
|
bank_id="test_bank",
|
|
documents=sample_documents,
|
|
)
|
|
|
|
# Execute the memory tool
|
|
response = await tools_impl.invoke_tool(
|
|
tool_name="memory",
|
|
args={
|
|
"messages": [
|
|
UserMessage(
|
|
content="What are the main topics covered in the documentation?",
|
|
)
|
|
],
|
|
"memory_bank_ids": ["test_bank"],
|
|
},
|
|
)
|
|
|
|
# Verify the response
|
|
assert isinstance(response, ToolInvocationResult)
|
|
assert response.content is not None
|
|
assert len(response.content) > 0
|