Support Tavily as built-in search tool. (#485)

# What does this PR do?

Add Tavily as a built-in search tool, in addition to Brave and Bing.

## Test Plan

It's tested using ollama remote, showing parity to the Brave search
tool.
- Install and run ollama with `ollama run llama3.1:8b-instruct-fp16`
- Build ollama distribution `llama stack build --template ollama
--image-type conda`
- Run ollama `stack run
/$USER/.llama/distributions/llamastack-ollama/ollama-run.yaml --port
5001`
- Client test command: `python - m
agents.test_agents.TestAgents.test_create_agent_turn_with_tavily_search`,
with enviroments:

MASTER_ADDR=0.0.0.0;MASTER_PORT=5001;RANK=0;REMOTE_STACK_HOST=0.0.0.0;REMOTE_STACK_PORT=5001;TAVILY_SEARCH_API_KEY=tvly-<YOUR-KEY>;WORLD_SIZE=1

Test passes on the specific case (ollama remote).

Server output: 
```
Listening on ['::', '0.0.0.0']:5001
INFO:     Started server process [7220]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://['::', '0.0.0.0']:5001 (Press CTRL+C to quit)
INFO:     127.0.0.1:65209 - "POST /agents/create HTTP/1.1" 200 OK
INFO:     127.0.0.1:65210 - "POST /agents/session/create HTTP/1.1" 200 OK
INFO:     127.0.0.1:65211 - "POST /agents/turn/create HTTP/1.1" 200 OK
role='user' content='What are the latest developments in quantum computing?' context=None
role='assistant' content='' stop_reason=<StopReason.end_of_turn: 'end_of_turn'> tool_calls=[ToolCall(call_id='fc92ccb8-1039-4ce8-ba5e-8f2b0147661c', tool_name=<BuiltinTool.brave_search: 'brave_search'>, arguments={'query': 'latest developments in quantum computing'})]
role='ipython' call_id='fc92ccb8-1039-4ce8-ba5e-8f2b0147661c' tool_name=<BuiltinTool.brave_search: 'brave_search'> content='{"query": "latest developments in quantum computing", "top_k": [{"title": "IBM Unveils 400 Qubit-Plus Quantum Processor and Next-Generation IBM ...", "url": "https://newsroom.ibm.com/2022-11-09-IBM-Unveils-400-Qubit-Plus-Quantum-Processor-and-Next-Generation-IBM-Quantum-System-Two", "content": "This system is targeted to be online by the end of 2023 and will be a building b...<more>...onnect large-scale ...", "url": "https://news.mit.edu/2023/quantum-interconnects-photon-emission-0105", "content": "Quantum computers hold the promise of performing certain tasks that are intractable even on the world\'s most powerful supercomputers. In the future, scientists anticipate using quantum computing to emulate materials systems, simulate quantum chemistry, and optimize hard tasks, with impacts potentially spanning finance to pharmaceuticals.", "score": 0.71721, "raw_content": null}]}'
Assistant: The latest developments in quantum computing include:

* IBM unveiling its 400 qubit-plus quantum processor and next-generation IBM Quantum System Two, which will be a building block of quantum-centric supercomputing.
* The development of utility-scale quantum computing, which can serve as a scientific tool to explore utility-scale classes of problems in chemistry, physics, and materials beyond brute force classical simulation of quantum mechanics.
* The introduction of advanced hardware across IBM's global fleet of 100+ qubit systems, as well as easy-to-use software that users and computational scientists can now obtain reliable results from quantum systems as they map increasingly larger and more complex problems to quantum circuits.
* Research on quantum repeaters, which use defects in diamond to interconnect quantum systems and could provide the foundation for scalable quantum networking.
* The development of a new source of quantum light, which could be used to improve the efficiency of quantum computers.
* The creation of a new mathematical "blueprint" that is accelerating fusion device development using Dyson maps.
* Research on canceling noise to improve quantum devices, with MIT researchers developing a protocol to extend the life of quantum coherence.
```

Verified with tool response. The final model response is updated with
the search requests.

## Sources

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Ran pre-commit to handle lint / formatting issues.
- [x] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [x] Updated relevant documentation.
- [x] Wrote necessary unit or integration tests.

Co-authored-by: Martin Yuan <myuan@meta.com>
This commit is contained in:
Mengtao Yuan 2024-11-19 20:59:02 -08:00 committed by GitHub
parent 08be023290
commit 1086b500f9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 106 additions and 49 deletions

View file

@ -54,6 +54,7 @@ class ToolDefinitionCommon(BaseModel):
class SearchEngineType(Enum):
bing = "bing"
brave = "brave"
tavily = "tavily"
@json_schema_type

View file

@ -86,10 +86,13 @@ class PhotogenTool(SingleMessageBuiltinTool):
class SearchTool(SingleMessageBuiltinTool):
def __init__(self, engine: SearchEngineType, api_key: str, **kwargs) -> None:
self.api_key = api_key
self.engine_type = engine
if engine == SearchEngineType.bing:
self.engine = BingSearch(api_key, **kwargs)
elif engine == SearchEngineType.brave:
self.engine = BraveSearch(api_key, **kwargs)
elif engine == SearchEngineType.tavily:
self.engine = TavilySearch(api_key, **kwargs)
else:
raise ValueError(f"Unknown search engine: {engine}")
@ -257,6 +260,21 @@ class BraveSearch:
return {"query": query, "top_k": clean_response}
class TavilySearch:
def __init__(self, api_key: str) -> None:
self.api_key = api_key
async def search(self, query: str) -> str:
response = requests.post(
"https://api.tavily.com/search",
json={"api_key": self.api_key, "query": query},
)
return json.dumps(self._clean_tavily_response(response.json()))
def _clean_tavily_response(self, search_response, top_k=3):
return {"query": search_response["query"], "top_k": search_response["results"]}
class WolframAlphaTool(SingleMessageBuiltinTool):
def __init__(self, api_key: str) -> None:
self.api_key = api_key

View file

@ -68,6 +68,73 @@ def query_attachment_messages():
]
async def create_agent_turn_with_search_tool(
agents_stack: Dict[str, object],
search_query_messages: List[object],
common_params: Dict[str, str],
search_tool_definition: SearchToolDefinition,
) -> None:
"""
Create an agent turn with a search tool.
Args:
agents_stack (Dict[str, object]): The agents stack.
search_query_messages (List[object]): The search query messages.
common_params (Dict[str, str]): The common parameters.
search_tool_definition (SearchToolDefinition): The search tool definition.
"""
# Create an agent with the search tool
agent_config = AgentConfig(
**{
**common_params,
"tools": [search_tool_definition],
}
)
agent_id, session_id = await create_agent_session(
agents_stack.impls[Api.agents], agent_config
)
turn_request = dict(
agent_id=agent_id,
session_id=session_id,
messages=search_query_messages,
stream=True,
)
turn_response = [
chunk
async for chunk in await agents_stack.impls[Api.agents].create_agent_turn(
**turn_request
)
]
assert len(turn_response) > 0
assert all(
isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
)
check_event_types(turn_response)
# Check for tool execution events
tool_execution_events = [
chunk
for chunk in turn_response
if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload)
and chunk.event.payload.step_details.step_type == StepType.tool_execution.value
]
assert len(tool_execution_events) > 0, "No tool execution events found"
# Check the tool execution details
tool_execution = tool_execution_events[0].event.payload.step_details
assert isinstance(tool_execution, ToolExecutionStep)
assert len(tool_execution.tool_calls) > 0
assert tool_execution.tool_calls[0].tool_name == BuiltinTool.brave_search
assert len(tool_execution.tool_responses) > 0
check_turn_complete_event(turn_response, session_id, search_query_messages)
class TestAgents:
@pytest.mark.asyncio
async def test_agent_turns_with_safety(
@ -215,63 +282,34 @@ class TestAgents:
async def test_create_agent_turn_with_brave_search(
self, agents_stack, search_query_messages, common_params
):
agents_impl = agents_stack.impls[Api.agents]
if "BRAVE_SEARCH_API_KEY" not in os.environ:
pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test")
# Create an agent with Brave search tool
agent_config = AgentConfig(
**{
**common_params,
"tools": [
SearchToolDefinition(
type=AgentTool.brave_search.value,
api_key=os.environ["BRAVE_SEARCH_API_KEY"],
engine=SearchEngineType.brave,
)
],
}
search_tool_definition = SearchToolDefinition(
type=AgentTool.brave_search.value,
api_key=os.environ["BRAVE_SEARCH_API_KEY"],
engine=SearchEngineType.brave,
)
await create_agent_turn_with_search_tool(
agents_stack, search_query_messages, common_params, search_tool_definition
)
agent_id, session_id = await create_agent_session(agents_impl, agent_config)
turn_request = dict(
agent_id=agent_id,
session_id=session_id,
messages=search_query_messages,
stream=True,
@pytest.mark.asyncio
async def test_create_agent_turn_with_tavily_search(
self, agents_stack, search_query_messages, common_params
):
if "TAVILY_SEARCH_API_KEY" not in os.environ:
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
search_tool_definition = SearchToolDefinition(
type=AgentTool.brave_search.value, # place holder only
api_key=os.environ["TAVILY_SEARCH_API_KEY"],
engine=SearchEngineType.tavily,
)
turn_response = [
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
]
assert len(turn_response) > 0
assert all(
isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
await create_agent_turn_with_search_tool(
agents_stack, search_query_messages, common_params, search_tool_definition
)
check_event_types(turn_response)
# Check for tool execution events
tool_execution_events = [
chunk
for chunk in turn_response
if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload)
and chunk.event.payload.step_details.step_type
== StepType.tool_execution.value
]
assert len(tool_execution_events) > 0, "No tool execution events found"
# Check the tool execution details
tool_execution = tool_execution_events[0].event.payload.step_details
assert isinstance(tool_execution, ToolExecutionStep)
assert len(tool_execution.tool_calls) > 0
assert tool_execution.tool_calls[0].tool_name == BuiltinTool.brave_search
assert len(tool_execution.tool_responses) > 0
check_turn_complete_event(turn_response, session_id, search_query_messages)
def check_event_types(turn_response):
event_types = [chunk.event.payload.event_type for chunk in turn_response]