diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index 8d52ac1b9..8075ea2bd 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -531,11 +531,6 @@ class ChatAgent(ShieldRunnerMixin):
             log.info(f"{str(message)}")
             tool_call = message.tool_calls[0]
 
-            name = tool_call.tool_name
-            if not isinstance(name, BuiltinTool) or name not in enabled_tools:
-                yield message
-                return
-
             step_id = str(uuid.uuid4())
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
@@ -691,10 +686,8 @@ async def execute_tool_call_maybe(
 
     tool_call = message.tool_calls[0]
     name = tool_call.tool_name
-    assert isinstance(name, BuiltinTool)
-
-    name = name.value
-
+    if isinstance(name, BuiltinTool):
+        name = name.value
     result = await tool_runtime_api.invoke_tool(
         tool_name=name,
         args=dict(
diff --git a/llama_stack/providers/inline/tool_runtime/tavily_search/tavily_search.py b/llama_stack/providers/inline/tool_runtime/tavily_search/tavily_search.py
index f80d10dfe..94a387f30 100644
--- a/llama_stack/providers/inline/tool_runtime/tavily_search/tavily_search.py
+++ b/llama_stack/providers/inline/tool_runtime/tavily_search/tavily_search.py
@@ -54,7 +54,6 @@ class TavilySearchToolRuntimeImpl(
             "https://api.tavily.com/search",
             json={"api_key": api_key, "query": args["query"]},
         )
-        print(f"================= Tavily response: {response.json()}")
 
         return ToolInvocationResult(
             content=json.dumps(self._clean_tavily_response(response.json()))
diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py
index cd4f75418..147f04b02 100644
--- a/llama_stack/providers/tests/agents/test_agents.py
+++ b/llama_stack/providers/tests/agents/test_agents.py
@@ -149,7 +149,10 @@ async def create_agent_turn_with_search_tool(
     tool_execution = tool_execution_events[0].event.payload.step_details
     assert isinstance(tool_execution, ToolExecutionStep)
     assert len(tool_execution.tool_calls) > 0
-    assert tool_execution.tool_calls[0].tool_name == tool_name
+    actual_tool_name = tool_execution.tool_calls[0].tool_name
+    if isinstance(actual_tool_name, BuiltinTool):
+        actual_tool_name = actual_tool_name.value
+    assert actual_tool_name == tool_name
     assert len(tool_execution.tool_responses) > 0
 
     check_turn_complete_event(turn_response, session_id, search_query_messages)
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index ed0cabe1c..d296105e0 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -14,7 +14,6 @@ from typing import List, Optional, Tuple, Union
 
 import httpx
 from llama_models.datatypes import is_multimodal, ModelFamily
-
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import (
     RawContent,
@@ -41,7 +40,6 @@ from llama_stack.apis.common.content_types import (
     InterleavedContentItem,
     TextContentItem,
 )
-
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     CompletionRequest,
@@ -52,7 +50,6 @@ from llama_stack.apis.inference import (
     ToolChoice,
     UserMessage,
 )
-
 from llama_stack.providers.utils.inference import supported_inference_models
 
 log = logging.getLogger(__name__)