mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 16:24:44 +00:00
Support Tavily as built-in search tool.
This commit is contained in:
parent
e8112b31ab
commit
9b76224c28
3 changed files with 106 additions and 49 deletions
|
@ -54,6 +54,7 @@ class ToolDefinitionCommon(BaseModel):
|
||||||
class SearchEngineType(Enum):
|
class SearchEngineType(Enum):
|
||||||
bing = "bing"
|
bing = "bing"
|
||||||
brave = "brave"
|
brave = "brave"
|
||||||
|
tavily = "tavily"
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
|
|
@ -86,10 +86,13 @@ class PhotogenTool(SingleMessageBuiltinTool):
|
||||||
class SearchTool(SingleMessageBuiltinTool):
|
class SearchTool(SingleMessageBuiltinTool):
|
||||||
def __init__(self, engine: SearchEngineType, api_key: str, **kwargs) -> None:
|
def __init__(self, engine: SearchEngineType, api_key: str, **kwargs) -> None:
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
|
self.engine_type = engine
|
||||||
if engine == SearchEngineType.bing:
|
if engine == SearchEngineType.bing:
|
||||||
self.engine = BingSearch(api_key, **kwargs)
|
self.engine = BingSearch(api_key, **kwargs)
|
||||||
elif engine == SearchEngineType.brave:
|
elif engine == SearchEngineType.brave:
|
||||||
self.engine = BraveSearch(api_key, **kwargs)
|
self.engine = BraveSearch(api_key, **kwargs)
|
||||||
|
elif engine == SearchEngineType.tavily:
|
||||||
|
self.engine = TavilySearch(api_key, **kwargs)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown search engine: {engine}")
|
raise ValueError(f"Unknown search engine: {engine}")
|
||||||
|
|
||||||
|
@ -257,6 +260,21 @@ class BraveSearch:
|
||||||
return {"query": query, "top_k": clean_response}
|
return {"query": query, "top_k": clean_response}
|
||||||
|
|
||||||
|
|
||||||
|
class TavilySearch:
|
||||||
|
def __init__(self, api_key: str) -> None:
|
||||||
|
self.api_key = api_key
|
||||||
|
|
||||||
|
async def search(self, query: str) -> str:
|
||||||
|
response = requests.post(
|
||||||
|
"https://api.tavily.com/search",
|
||||||
|
json={"api_key": self.api_key, "query": query},
|
||||||
|
)
|
||||||
|
return json.dumps(self._clean_tavily_response(response.json()))
|
||||||
|
|
||||||
|
def _clean_tavily_response(self, search_response, top_k=3):
|
||||||
|
return {"query": search_response["query"], "top_k": search_response["results"]}
|
||||||
|
|
||||||
|
|
||||||
class WolframAlphaTool(SingleMessageBuiltinTool):
|
class WolframAlphaTool(SingleMessageBuiltinTool):
|
||||||
def __init__(self, api_key: str) -> None:
|
def __init__(self, api_key: str) -> None:
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
|
|
|
@ -68,6 +68,73 @@ def query_attachment_messages():
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def create_agent_turn_with_search_tool(
|
||||||
|
agents_stack: Dict[str, object],
|
||||||
|
search_query_messages: List[object],
|
||||||
|
common_params: Dict[str, str],
|
||||||
|
search_tool_definition: SearchToolDefinition,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Create an agent turn with a search tool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agents_stack (Dict[str, object]): The agents stack.
|
||||||
|
search_query_messages (List[object]): The search query messages.
|
||||||
|
common_params (Dict[str, str]): The common parameters.
|
||||||
|
search_tool_definition (SearchToolDefinition): The search tool definition.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create an agent with the search tool
|
||||||
|
agent_config = AgentConfig(
|
||||||
|
**{
|
||||||
|
**common_params,
|
||||||
|
"tools": [search_tool_definition],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
agent_id, session_id = await create_agent_session(
|
||||||
|
agents_stack.impls[Api.agents], agent_config
|
||||||
|
)
|
||||||
|
turn_request = dict(
|
||||||
|
agent_id=agent_id,
|
||||||
|
session_id=session_id,
|
||||||
|
messages=search_query_messages,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
turn_response = [
|
||||||
|
chunk
|
||||||
|
async for chunk in await agents_stack.impls[Api.agents].create_agent_turn(
|
||||||
|
**turn_request
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(turn_response) > 0
|
||||||
|
assert all(
|
||||||
|
isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
|
||||||
|
)
|
||||||
|
|
||||||
|
check_event_types(turn_response)
|
||||||
|
|
||||||
|
# Check for tool execution events
|
||||||
|
tool_execution_events = [
|
||||||
|
chunk
|
||||||
|
for chunk in turn_response
|
||||||
|
if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload)
|
||||||
|
and chunk.event.payload.step_details.step_type == StepType.tool_execution.value
|
||||||
|
]
|
||||||
|
assert len(tool_execution_events) > 0, "No tool execution events found"
|
||||||
|
|
||||||
|
# Check the tool execution details
|
||||||
|
tool_execution = tool_execution_events[0].event.payload.step_details
|
||||||
|
assert isinstance(tool_execution, ToolExecutionStep)
|
||||||
|
assert len(tool_execution.tool_calls) > 0
|
||||||
|
assert tool_execution.tool_calls[0].tool_name == BuiltinTool.brave_search
|
||||||
|
assert len(tool_execution.tool_responses) > 0
|
||||||
|
|
||||||
|
check_turn_complete_event(turn_response, session_id, search_query_messages)
|
||||||
|
|
||||||
|
|
||||||
class TestAgents:
|
class TestAgents:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_agent_turns_with_safety(
|
async def test_agent_turns_with_safety(
|
||||||
|
@ -215,63 +282,34 @@ class TestAgents:
|
||||||
async def test_create_agent_turn_with_brave_search(
|
async def test_create_agent_turn_with_brave_search(
|
||||||
self, agents_stack, search_query_messages, common_params
|
self, agents_stack, search_query_messages, common_params
|
||||||
):
|
):
|
||||||
agents_impl = agents_stack.impls[Api.agents]
|
|
||||||
|
|
||||||
if "BRAVE_SEARCH_API_KEY" not in os.environ:
|
if "BRAVE_SEARCH_API_KEY" not in os.environ:
|
||||||
pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test")
|
pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test")
|
||||||
|
|
||||||
# Create an agent with Brave search tool
|
search_tool_definition = SearchToolDefinition(
|
||||||
agent_config = AgentConfig(
|
type=AgentTool.brave_search.value,
|
||||||
**{
|
api_key=os.environ["BRAVE_SEARCH_API_KEY"],
|
||||||
**common_params,
|
engine=SearchEngineType.brave,
|
||||||
"tools": [
|
)
|
||||||
SearchToolDefinition(
|
await create_agent_turn_with_search_tool(
|
||||||
type=AgentTool.brave_search.value,
|
agents_stack, search_query_messages, common_params, search_tool_definition
|
||||||
api_key=os.environ["BRAVE_SEARCH_API_KEY"],
|
|
||||||
engine=SearchEngineType.brave,
|
|
||||||
)
|
|
||||||
],
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
agent_id, session_id = await create_agent_session(agents_impl, agent_config)
|
@pytest.mark.asyncio
|
||||||
turn_request = dict(
|
async def test_create_agent_turn_with_tavily_search(
|
||||||
agent_id=agent_id,
|
self, agents_stack, search_query_messages, common_params
|
||||||
session_id=session_id,
|
):
|
||||||
messages=search_query_messages,
|
if "TAVILY_SEARCH_API_KEY" not in os.environ:
|
||||||
stream=True,
|
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
|
||||||
|
|
||||||
|
search_tool_definition = SearchToolDefinition(
|
||||||
|
type=AgentTool.brave_search.value, # place holder only
|
||||||
|
api_key=os.environ["TAVILY_SEARCH_API_KEY"],
|
||||||
|
engine=SearchEngineType.tavily,
|
||||||
)
|
)
|
||||||
|
await create_agent_turn_with_search_tool(
|
||||||
turn_response = [
|
agents_stack, search_query_messages, common_params, search_tool_definition
|
||||||
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
|
|
||||||
]
|
|
||||||
|
|
||||||
assert len(turn_response) > 0
|
|
||||||
assert all(
|
|
||||||
isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
|
|
||||||
)
|
)
|
||||||
|
|
||||||
check_event_types(turn_response)
|
|
||||||
|
|
||||||
# Check for tool execution events
|
|
||||||
tool_execution_events = [
|
|
||||||
chunk
|
|
||||||
for chunk in turn_response
|
|
||||||
if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload)
|
|
||||||
and chunk.event.payload.step_details.step_type
|
|
||||||
== StepType.tool_execution.value
|
|
||||||
]
|
|
||||||
assert len(tool_execution_events) > 0, "No tool execution events found"
|
|
||||||
|
|
||||||
# Check the tool execution details
|
|
||||||
tool_execution = tool_execution_events[0].event.payload.step_details
|
|
||||||
assert isinstance(tool_execution, ToolExecutionStep)
|
|
||||||
assert len(tool_execution.tool_calls) > 0
|
|
||||||
assert tool_execution.tool_calls[0].tool_name == BuiltinTool.brave_search
|
|
||||||
assert len(tool_execution.tool_responses) > 0
|
|
||||||
|
|
||||||
check_turn_complete_event(turn_response, session_id, search_query_messages)
|
|
||||||
|
|
||||||
|
|
||||||
def check_event_types(turn_response):
|
def check_event_types(turn_response):
|
||||||
event_types = [chunk.event.payload.event_type for chunk in turn_response]
|
event_types = [chunk.event.payload.event_type for chunk in turn_response]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue