Mirror of https://github.com/meta-llama/llama-stack.git
Synced 2025-08-03 17:29:01 +00:00

Commit 2ad67529ef: fix agents to run custom tools
Parent: 9192a9bbb4

4 changed files with 6 additions and 14 deletions
@@ -531,11 +531,6 @@ class ChatAgent(ShieldRunnerMixin):
             log.info(f"{str(message)}")
             tool_call = message.tool_calls[0]

-            name = tool_call.tool_name
-            if not isinstance(name, BuiltinTool) or name not in enabled_tools:
-                yield message
-                return
-
             step_id = str(uuid.uuid4())
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
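This deleted guard is what kept custom tools from running: a custom tool's name arrives as a plain string rather than a BuiltinTool enum member, so `not isinstance(name, BuiltinTool)` was always true for it, and the agent yielded the unexecuted tool-call message and returned early. A minimal sketch of the old behavior, using a stand-in enum (the real BuiltinTool lives in llama_models; the member set and `enabled_tools` value here are illustrative):

from enum import Enum


class BuiltinTool(Enum):
    # Stand-in for llama_models' BuiltinTool; member set is illustrative.
    brave_search = "brave_search"
    code_interpreter = "code_interpreter"


enabled_tools = {BuiltinTool.brave_search}  # hypothetical enabled set


def old_gate_skips(name) -> bool:
    # Mirrors the deleted condition: skip execution unless the tool is a
    # builtin that is also enabled.
    return not isinstance(name, BuiltinTool) or name not in enabled_tools


assert old_gate_skips("my_custom_tool")              # custom tool: never executed
assert not old_gate_skips(BuiltinTool.brave_search)  # enabled builtin: executed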
@@ -691,10 +686,8 @@ async def execute_tool_call_maybe(

     tool_call = message.tool_calls[0]
     name = tool_call.tool_name
-    assert isinstance(name, BuiltinTool)
-
-    name = name.value
-
+    if isinstance(name, BuiltinTool):
+        name = name.value
     result = await tool_runtime_api.invoke_tool(
         tool_name=name,
         args=dict(
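Downstream, replacing the assert with a conditional lets `execute_tool_call_maybe` accept both kinds of names: a BuiltinTool member is unwrapped to its string value, while a custom tool's name is already the string that `invoke_tool` expects. A self-contained sketch of that normalization, with the same stand-in enum:

from enum import Enum


class BuiltinTool(Enum):  # stand-in for llama_models' BuiltinTool
    code_interpreter = "code_interpreter"


def normalize_tool_name(name) -> str:
    # Matches the new code path: unwrap builtins, pass custom names through.
    if isinstance(name, BuiltinTool):
        name = name.value
    return name


assert normalize_tool_name(BuiltinTool.code_interpreter) == "code_interpreter"
assert normalize_tool_name("my_custom_tool") == "my_custom_tool"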
@@ -54,7 +54,6 @@ class TavilySearchToolRuntimeImpl(
             "https://api.tavily.com/search",
             json={"api_key": api_key, "query": args["query"]},
         )
-        print(f"================= Tavily response: {response.json()}")

         return ToolInvocationResult(
             content=json.dumps(self._clean_tavily_response(response.json()))
@@ -149,7 +149,10 @@ async def create_agent_turn_with_search_tool(
     tool_execution = tool_execution_events[0].event.payload.step_details
     assert isinstance(tool_execution, ToolExecutionStep)
     assert len(tool_execution.tool_calls) > 0
-    assert tool_execution.tool_calls[0].tool_name == tool_name
+    actual_tool_name = tool_execution.tool_calls[0].tool_name
+    if isinstance(actual_tool_name, BuiltinTool):
+        actual_tool_name = actual_tool_name.value
+    assert actual_tool_name == tool_name
     assert len(tool_execution.tool_responses) > 0

     check_turn_complete_event(turn_response, session_id, search_query_messages)
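The client-SDK test changes for the same reason: the `tool_name` reported on the ToolExecutionStep may presumably be either a BuiltinTool member or a plain string, so a direct comparison against the expected string could fail for builtins. The new assertions normalize before comparing; a compact sketch of the pattern (values hypothetical):

from enum import Enum


class BuiltinTool(Enum):  # stand-in for llama_models' BuiltinTool
    brave_search = "brave_search"


tool_name = "brave_search"  # expected name, as a plain string

# The step may report the enum member or the raw string; both should pass.
for reported in (BuiltinTool.brave_search, "brave_search"):
    actual_tool_name = reported
    if isinstance(actual_tool_name, BuiltinTool):
        actual_tool_name = actual_tool_name.value
    assert actual_tool_name == tool_name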
@@ -14,7 +14,6 @@ from typing import List, Optional, Tuple, Union

 import httpx
 from llama_models.datatypes import is_multimodal, ModelFamily
-
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import (
     RawContent,
@@ -41,7 +40,6 @@ from llama_stack.apis.common.content_types import (
     InterleavedContentItem,
     TextContentItem,
 )
-
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     CompletionRequest,
@@ -52,7 +50,6 @@ from llama_stack.apis.inference import (
     ToolChoice,
     UserMessage,
 )
-
 from llama_stack.providers.utils.inference import supported_inference_models

 log = logging.getLogger(__name__)