Support Tavily as built-in search tool.

This commit is contained in:
Martin Yuan 2024-11-18 17:59:50 -08:00
parent e8112b31ab
commit 9b76224c28
3 changed files with 106 additions and 49 deletions

View file

@ -54,6 +54,7 @@ class ToolDefinitionCommon(BaseModel):
class SearchEngineType(Enum): class SearchEngineType(Enum):
bing = "bing" bing = "bing"
brave = "brave" brave = "brave"
tavily = "tavily"
@json_schema_type @json_schema_type

View file

@ -86,10 +86,13 @@ class PhotogenTool(SingleMessageBuiltinTool):
class SearchTool(SingleMessageBuiltinTool): class SearchTool(SingleMessageBuiltinTool):
def __init__(self, engine: SearchEngineType, api_key: str, **kwargs) -> None: def __init__(self, engine: SearchEngineType, api_key: str, **kwargs) -> None:
self.api_key = api_key self.api_key = api_key
self.engine_type = engine
if engine == SearchEngineType.bing: if engine == SearchEngineType.bing:
self.engine = BingSearch(api_key, **kwargs) self.engine = BingSearch(api_key, **kwargs)
elif engine == SearchEngineType.brave: elif engine == SearchEngineType.brave:
self.engine = BraveSearch(api_key, **kwargs) self.engine = BraveSearch(api_key, **kwargs)
elif engine == SearchEngineType.tavily:
self.engine = TavilySearch(api_key, **kwargs)
else: else:
raise ValueError(f"Unknown search engine: {engine}") raise ValueError(f"Unknown search engine: {engine}")
@ -257,6 +260,21 @@ class BraveSearch:
return {"query": query, "top_k": clean_response} return {"query": query, "top_k": clean_response}
class TavilySearch:
    """Minimal client for the Tavily web-search HTTP API.

    The query is POSTed to ``https://api.tavily.com/search`` together with
    the API key, and the response is reduced to a JSON string holding the
    echoed ``query`` and the leading ``top_k`` results.
    """

    def __init__(self, api_key: str) -> None:
        """Store the Tavily API key that is sent with every search request."""
        self.api_key = api_key

    async def search(self, query: str) -> str:
        """Run ``query`` against Tavily and return the cleaned result as JSON.

        Args:
            query: The search string forwarded to the Tavily API.

        Returns:
            A JSON-encoded object with the keys ``query`` and ``top_k``.

        Raises:
            requests.HTTPError: If Tavily answers with a 4xx/5xx status.
        """
        # NOTE: requests.post is a synchronous call; it is executed inline
        # inside this coroutine, exactly as in the original implementation.
        response = requests.post(
            "https://api.tavily.com/search",
            json={"api_key": self.api_key, "query": query},
        )
        # Fail with the real HTTP error instead of an opaque KeyError raised
        # later when the error body lacks the "query"/"results" keys.
        response.raise_for_status()
        return json.dumps(self._clean_tavily_response(response.json()))

    def _clean_tavily_response(self, search_response, top_k=3):
        """Reduce a raw Tavily response to its query and first ``top_k`` results.

        Args:
            search_response: Decoded JSON body; must contain the keys
                ``"query"`` and ``"results"`` (a list).
            top_k: Maximum number of results to keep. ``None`` keeps all.

        Returns:
            ``{"query": <query>, "top_k": <results[:top_k]>}``.
        """
        # Slicing honours the previously ignored ``top_k`` argument; a slice
        # bound of None leaves the list untouched.
        return {
            "query": search_response["query"],
            "top_k": search_response["results"][:top_k],
        }
class WolframAlphaTool(SingleMessageBuiltinTool): class WolframAlphaTool(SingleMessageBuiltinTool):
def __init__(self, api_key: str) -> None: def __init__(self, api_key: str) -> None:
self.api_key = api_key self.api_key = api_key

View file

@ -68,6 +68,73 @@ def query_attachment_messages():
] ]
async def create_agent_turn_with_search_tool(
    agents_stack: Dict[str, object],
    search_query_messages: List[object],
    common_params: Dict[str, str],
    search_tool_definition: SearchToolDefinition,
) -> None:
    """
    Drive one streamed agent turn with a search tool and validate its events.

    An agent is configured with ``search_tool_definition`` as its only tool,
    a session is opened, and the search messages are sent as a streaming
    turn. The collected chunks are then checked for the expected event
    types, at least one completed tool-execution step, and a final
    turn-complete event.

    Args:
        agents_stack: The agents stack; its ``impls[Api.agents]`` entry is
            used to create the session and the turn.
        search_query_messages: The messages sent as the turn input.
        common_params: Base keyword arguments for ``AgentConfig``.
        search_tool_definition: The search tool the agent is configured with.
    """
    # Build the agent configuration: the shared parameters plus the one tool.
    config_kwargs = dict(common_params)
    config_kwargs["tools"] = [search_tool_definition]
    agent_config = AgentConfig(**config_kwargs)

    agents_impl = agents_stack.impls[Api.agents]
    agent_id, session_id = await create_agent_session(agents_impl, agent_config)

    # Start the streaming turn and drain every chunk it yields.
    stream = await agents_impl.create_agent_turn(
        agent_id=agent_id,
        session_id=session_id,
        messages=search_query_messages,
        stream=True,
    )
    turn_response = []
    async for chunk in stream:
        turn_response.append(chunk)

    assert len(turn_response) > 0
    assert all(
        isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
    )

    check_event_types(turn_response)

    # Keep only the chunks that report a completed tool-execution step.
    tool_execution_events = []
    for chunk in turn_response:
        payload = chunk.event.payload
        if not isinstance(payload, AgentTurnResponseStepCompletePayload):
            continue
        if payload.step_details.step_type == StepType.tool_execution.value:
            tool_execution_events.append(chunk)
    assert len(tool_execution_events) > 0, "No tool execution events found"

    # Inspect the first tool-execution step in detail.
    first_step = tool_execution_events[0].event.payload.step_details
    assert isinstance(first_step, ToolExecutionStep)
    assert len(first_step.tool_calls) > 0
    assert first_step.tool_calls[0].tool_name == BuiltinTool.brave_search
    assert len(first_step.tool_responses) > 0

    check_turn_complete_event(turn_response, session_id, search_query_messages)
class TestAgents: class TestAgents:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_agent_turns_with_safety( async def test_agent_turns_with_safety(
@ -215,63 +282,34 @@ class TestAgents:
async def test_create_agent_turn_with_brave_search( async def test_create_agent_turn_with_brave_search(
self, agents_stack, search_query_messages, common_params self, agents_stack, search_query_messages, common_params
): ):
agents_impl = agents_stack.impls[Api.agents]
if "BRAVE_SEARCH_API_KEY" not in os.environ: if "BRAVE_SEARCH_API_KEY" not in os.environ:
pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test") pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test")
# Create an agent with Brave search tool search_tool_definition = SearchToolDefinition(
agent_config = AgentConfig( type=AgentTool.brave_search.value,
**{ api_key=os.environ["BRAVE_SEARCH_API_KEY"],
**common_params, engine=SearchEngineType.brave,
"tools": [ )
SearchToolDefinition( await create_agent_turn_with_search_tool(
type=AgentTool.brave_search.value, agents_stack, search_query_messages, common_params, search_tool_definition
api_key=os.environ["BRAVE_SEARCH_API_KEY"],
engine=SearchEngineType.brave,
)
],
}
) )
agent_id, session_id = await create_agent_session(agents_impl, agent_config) @pytest.mark.asyncio
turn_request = dict( async def test_create_agent_turn_with_tavily_search(
agent_id=agent_id, self, agents_stack, search_query_messages, common_params
session_id=session_id, ):
messages=search_query_messages, if "TAVILY_SEARCH_API_KEY" not in os.environ:
stream=True, pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
search_tool_definition = SearchToolDefinition(
type=AgentTool.brave_search.value, # place holder only
api_key=os.environ["TAVILY_SEARCH_API_KEY"],
engine=SearchEngineType.tavily,
) )
await create_agent_turn_with_search_tool(
turn_response = [ agents_stack, search_query_messages, common_params, search_tool_definition
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
]
assert len(turn_response) > 0
assert all(
isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
) )
check_event_types(turn_response)
# Check for tool execution events
tool_execution_events = [
chunk
for chunk in turn_response
if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload)
and chunk.event.payload.step_details.step_type
== StepType.tool_execution.value
]
assert len(tool_execution_events) > 0, "No tool execution events found"
# Check the tool execution details
tool_execution = tool_execution_events[0].event.payload.step_details
assert isinstance(tool_execution, ToolExecutionStep)
assert len(tool_execution.tool_calls) > 0
assert tool_execution.tool_calls[0].tool_name == BuiltinTool.brave_search
assert len(tool_execution.tool_responses) > 0
check_turn_complete_event(turn_response, session_id, search_query_messages)
def check_event_types(turn_response): def check_event_types(turn_response):
event_types = [chunk.event.payload.event_type for chunk in turn_response] event_types = [chunk.event.payload.event_type for chunk in turn_response]