fix agents to run custom tools

This commit is contained in:
Dinesh Yeduguru 2024-12-20 22:02:00 -08:00
parent 9192a9bbb4
commit 2ad67529ef
4 changed files with 6 additions and 14 deletions

View file

@ -531,11 +531,6 @@ class ChatAgent(ShieldRunnerMixin):
log.info(f"{str(message)}")
tool_call = message.tool_calls[0]
name = tool_call.tool_name
if not isinstance(name, BuiltinTool) or name not in enabled_tools:
yield message
return
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
@ -691,10 +686,8 @@ async def execute_tool_call_maybe(
tool_call = message.tool_calls[0]
name = tool_call.tool_name
assert isinstance(name, BuiltinTool)
name = name.value
if isinstance(name, BuiltinTool):
name = name.value
result = await tool_runtime_api.invoke_tool(
tool_name=name,
args=dict(

View file

@ -54,7 +54,6 @@ class TavilySearchToolRuntimeImpl(
"https://api.tavily.com/search",
json={"api_key": api_key, "query": args["query"]},
)
print(f"================= Tavily response: {response.json()}")
return ToolInvocationResult(
content=json.dumps(self._clean_tavily_response(response.json()))

View file

@ -149,7 +149,10 @@ async def create_agent_turn_with_search_tool(
tool_execution = tool_execution_events[0].event.payload.step_details
assert isinstance(tool_execution, ToolExecutionStep)
assert len(tool_execution.tool_calls) > 0
assert tool_execution.tool_calls[0].tool_name == tool_name
actual_tool_name = tool_execution.tool_calls[0].tool_name
if isinstance(actual_tool_name, BuiltinTool):
actual_tool_name = actual_tool_name.value
assert actual_tool_name == tool_name
assert len(tool_execution.tool_responses) > 0
check_turn_complete_event(turn_response, session_id, search_query_messages)

View file

@ -14,7 +14,6 @@ from typing import List, Optional, Tuple, Union
import httpx
from llama_models.datatypes import is_multimodal, ModelFamily
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import (
RawContent,
@ -41,7 +40,6 @@ from llama_stack.apis.common.content_types import (
InterleavedContentItem,
TextContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest,
@ -52,7 +50,6 @@ from llama_stack.apis.inference import (
ToolChoice,
UserMessage,
)
from llama_stack.providers.utils.inference import supported_inference_models
log = logging.getLogger(__name__)