fix agents to run custom tools

This commit is contained in:
Dinesh Yeduguru 2024-12-20 22:02:00 -08:00
parent 9192a9bbb4
commit 2ad67529ef
4 changed files with 6 additions and 14 deletions

View file

@ -531,11 +531,6 @@ class ChatAgent(ShieldRunnerMixin):
log.info(f"{str(message)}") log.info(f"{str(message)}")
tool_call = message.tool_calls[0] tool_call = message.tool_calls[0]
name = tool_call.tool_name
if not isinstance(name, BuiltinTool) or name not in enabled_tools:
yield message
return
step_id = str(uuid.uuid4()) step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
@ -691,10 +686,8 @@ async def execute_tool_call_maybe(
tool_call = message.tool_calls[0] tool_call = message.tool_calls[0]
name = tool_call.tool_name name = tool_call.tool_name
assert isinstance(name, BuiltinTool) if isinstance(name, BuiltinTool):
name = name.value
name = name.value
result = await tool_runtime_api.invoke_tool( result = await tool_runtime_api.invoke_tool(
tool_name=name, tool_name=name,
args=dict( args=dict(

View file

@ -54,7 +54,6 @@ class TavilySearchToolRuntimeImpl(
"https://api.tavily.com/search", "https://api.tavily.com/search",
json={"api_key": api_key, "query": args["query"]}, json={"api_key": api_key, "query": args["query"]},
) )
print(f"================= Tavily response: {response.json()}")
return ToolInvocationResult( return ToolInvocationResult(
content=json.dumps(self._clean_tavily_response(response.json())) content=json.dumps(self._clean_tavily_response(response.json()))

View file

@ -149,7 +149,10 @@ async def create_agent_turn_with_search_tool(
tool_execution = tool_execution_events[0].event.payload.step_details tool_execution = tool_execution_events[0].event.payload.step_details
assert isinstance(tool_execution, ToolExecutionStep) assert isinstance(tool_execution, ToolExecutionStep)
assert len(tool_execution.tool_calls) > 0 assert len(tool_execution.tool_calls) > 0
assert tool_execution.tool_calls[0].tool_name == tool_name actual_tool_name = tool_execution.tool_calls[0].tool_name
if isinstance(actual_tool_name, BuiltinTool):
actual_tool_name = actual_tool_name.value
assert actual_tool_name == tool_name
assert len(tool_execution.tool_responses) > 0 assert len(tool_execution.tool_responses) > 0
check_turn_complete_event(turn_response, session_id, search_query_messages) check_turn_complete_event(turn_response, session_id, search_query_messages)

View file

@ -14,7 +14,6 @@ from typing import List, Optional, Tuple, Union
import httpx import httpx
from llama_models.datatypes import is_multimodal, ModelFamily from llama_models.datatypes import is_multimodal, ModelFamily
from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import ( from llama_models.llama3.api.datatypes import (
RawContent, RawContent,
@ -41,7 +40,6 @@ from llama_stack.apis.common.content_types import (
InterleavedContentItem, InterleavedContentItem,
TextContentItem, TextContentItem,
) )
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
ChatCompletionRequest, ChatCompletionRequest,
CompletionRequest, CompletionRequest,
@ -52,7 +50,6 @@ from llama_stack.apis.inference import (
ToolChoice, ToolChoice,
UserMessage, UserMessage,
) )
from llama_stack.providers.utils.inference import supported_inference_models from llama_stack.providers.utils.inference import supported_inference_models
log = logging.getLogger(__name__) log = logging.getLogger(__name__)