mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-23 00:12:24 +00:00
address feedback
This commit is contained in:
parent
8bf3f8ea56
commit
ac46bd5eb4
8 changed files with 24 additions and 35 deletions
|
|
@ -429,9 +429,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
n_iter = 0
|
||||
# Build a map of custom tools to their definitions for faster lookup
|
||||
custom_tools = {}
|
||||
for tool in self.agent_config.custom_tools:
|
||||
custom_tools[tool.name] = tool
|
||||
client_tools = {}
|
||||
for tool in self.agent_config.client_tools:
|
||||
client_tools[tool.name] = tool
|
||||
while True:
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
|
|
@ -560,7 +560,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
else:
|
||||
log.info(f"{str(message)}")
|
||||
tool_call = message.tool_calls[0]
|
||||
if tool_call.tool_name in custom_tools:
|
||||
if tool_call.tool_name in client_tools:
|
||||
yield message
|
||||
return
|
||||
|
||||
|
|
@ -656,7 +656,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
async def _get_tools(self) -> List[ToolDefinition]:
|
||||
ret = []
|
||||
for tool in self.agent_config.custom_tools:
|
||||
for tool in self.agent_config.client_tools:
|
||||
params = {}
|
||||
for param in tool.parameters:
|
||||
params[param.name] = ToolParamDefinition(
|
||||
|
|
@ -672,7 +672,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
parameters=params,
|
||||
)
|
||||
)
|
||||
for tool_name in self.agent_config.available_tools:
|
||||
for tool_name in self.agent_config.tool_names:
|
||||
tool = await self.tool_groups_api.get_tool(tool_name)
|
||||
if tool.built_in_type:
|
||||
ret.append(ToolDefinition(tool_name=tool.built_in_type))
|
||||
|
|
|
|||
|
|
@ -8,13 +8,13 @@ from typing import Any, Dict, List
|
|||
from urllib.parse import urlparse
|
||||
|
||||
from llama_stack.apis.tools import (
|
||||
CustomToolDef,
|
||||
MCPToolGroupDef,
|
||||
ToolDef,
|
||||
ToolGroupDef,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
UserDefinedToolDef,
|
||||
)
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
|
||||
|
|
@ -53,7 +53,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
)
|
||||
)
|
||||
tools.append(
|
||||
CustomToolDef(
|
||||
UserDefinedToolDef(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
parameters=parameters,
|
||||
|
|
|
|||
|
|
@ -104,7 +104,7 @@ async def create_agent_turn_with_search_tool(
|
|||
agent_config = AgentConfig(
|
||||
**{
|
||||
**common_params,
|
||||
"available_tools": [tool_name],
|
||||
"tool_names": [tool_name],
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -13,9 +13,9 @@ from llama_models.llama3.api.datatypes import BuiltinTool
|
|||
from llama_stack.apis.models import ModelInput, ModelType
|
||||
from llama_stack.apis.tools import (
|
||||
BuiltInToolDef,
|
||||
CustomToolDef,
|
||||
ToolGroupInput,
|
||||
ToolParameter,
|
||||
UserDefinedToolDef,
|
||||
UserDefinedToolGroupDef,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import Api, Provider
|
||||
|
|
@ -50,7 +50,7 @@ def tool_group_input_memory() -> ToolGroupInput:
|
|||
tool_group_id="memory_group",
|
||||
tool_group=UserDefinedToolGroupDef(
|
||||
tools=[
|
||||
CustomToolDef(
|
||||
UserDefinedToolDef(
|
||||
name="memory",
|
||||
description="Query the memory bank",
|
||||
parameters=[
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue