add tavily

This commit is contained in:
Dinesh Yeduguru 2024-12-20 21:01:41 -08:00
parent dcdf9da6ef
commit 9192a9bbb4
6 changed files with 163 additions and 2 deletions

View file

@ -77,6 +77,13 @@ def tool_runtime_memory() -> ProviderFixture:
"api_key": os.environ["BRAVE_SEARCH_API_KEY"],
},
),
Provider(
provider_id="tavily-search",
provider_type="inline::tavily-search",
config={
"api_key": os.environ["TAVILY_SEARCH_API_KEY"],
},
),
],
)
@ -146,13 +153,41 @@ async def agents_stack(request, inference_model, safety_shield):
ToolDef(
name="brave_search",
description="brave_search",
parameters=[],
parameters=[
ToolParameter(
name="query",
description="query",
parameter_type="string",
required=True,
),
],
metadata={},
),
],
),
provider_id="brave-search",
),
ToolGroupInput(
tool_group_id="tavily_search_group",
tool_group=UserDefinedToolGroupDef(
tools=[
ToolDef(
name="tavily_search",
description="tavily_search",
parameters=[
ToolParameter(
name="query",
description="query",
parameter_type="string",
required=True,
),
],
metadata={},
),
],
),
provider_id="tavily-search",
),
ToolGroupInput(
tool_group_id="memory_group",
tool_group=UserDefinedToolGroupDef(

View file

@ -149,7 +149,7 @@ async def create_agent_turn_with_search_tool(
tool_execution = tool_execution_events[0].event.payload.step_details
assert isinstance(tool_execution, ToolExecutionStep)
assert len(tool_execution.tool_calls) > 0
assert tool_execution.tool_calls[0].tool_name == BuiltinTool.brave_search
assert tool_execution.tool_calls[0].tool_name == tool_name
assert len(tool_execution.tool_responses) > 0
check_turn_complete_event(turn_response, session_id, search_query_messages)
@ -302,6 +302,20 @@ class TestAgents:
"brave_search",
)
@pytest.mark.asyncio
async def test_create_agent_turn_with_tavily_search(
self, agents_stack, search_query_messages, common_params
):
if "TAVILY_SEARCH_API_KEY" not in os.environ:
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
await create_agent_turn_with_search_tool(
agents_stack,
search_query_messages,
common_params,
"tavily_search",
)
def check_event_types(turn_response):
event_types = [chunk.event.payload.event_type for chunk in turn_response]