mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 17:29:01 +00:00
fix agents to run custom tools
This commit is contained in:
parent
9192a9bbb4
commit
2ad67529ef
4 changed files with 6 additions and 14 deletions
|
@ -531,11 +531,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
log.info(f"{str(message)}")
|
||||
tool_call = message.tool_calls[0]
|
||||
|
||||
name = tool_call.tool_name
|
||||
if not isinstance(name, BuiltinTool) or name not in enabled_tools:
|
||||
yield message
|
||||
return
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
|
@ -691,10 +686,8 @@ async def execute_tool_call_maybe(
|
|||
|
||||
tool_call = message.tool_calls[0]
|
||||
name = tool_call.tool_name
|
||||
assert isinstance(name, BuiltinTool)
|
||||
|
||||
name = name.value
|
||||
|
||||
if isinstance(name, BuiltinTool):
|
||||
name = name.value
|
||||
result = await tool_runtime_api.invoke_tool(
|
||||
tool_name=name,
|
||||
args=dict(
|
||||
|
|
|
@ -54,7 +54,6 @@ class TavilySearchToolRuntimeImpl(
|
|||
"https://api.tavily.com/search",
|
||||
json={"api_key": api_key, "query": args["query"]},
|
||||
)
|
||||
print(f"================= Tavily response: {response.json()}")
|
||||
|
||||
return ToolInvocationResult(
|
||||
content=json.dumps(self._clean_tavily_response(response.json()))
|
||||
|
|
|
@ -149,7 +149,10 @@ async def create_agent_turn_with_search_tool(
|
|||
tool_execution = tool_execution_events[0].event.payload.step_details
|
||||
assert isinstance(tool_execution, ToolExecutionStep)
|
||||
assert len(tool_execution.tool_calls) > 0
|
||||
assert tool_execution.tool_calls[0].tool_name == tool_name
|
||||
actual_tool_name = tool_execution.tool_calls[0].tool_name
|
||||
if isinstance(actual_tool_name, BuiltinTool):
|
||||
actual_tool_name = actual_tool_name.value
|
||||
assert actual_tool_name == tool_name
|
||||
assert len(tool_execution.tool_responses) > 0
|
||||
|
||||
check_turn_complete_event(turn_response, session_id, search_query_messages)
|
||||
|
|
|
@ -14,7 +14,6 @@ from typing import List, Optional, Tuple, Union
|
|||
|
||||
import httpx
|
||||
from llama_models.datatypes import is_multimodal, ModelFamily
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.datatypes import (
|
||||
RawContent,
|
||||
|
@ -41,7 +40,6 @@ from llama_stack.apis.common.content_types import (
|
|||
InterleavedContentItem,
|
||||
TextContentItem,
|
||||
)
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
CompletionRequest,
|
||||
|
@ -52,7 +50,6 @@ from llama_stack.apis.inference import (
|
|||
ToolChoice,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
from llama_stack.providers.utils.inference import supported_inference_models
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue