address feedback

This commit is contained in:
Dinesh Yeduguru 2024-12-30 15:47:01 -08:00
parent 8bf3f8ea56
commit ac46bd5eb4
8 changed files with 24 additions and 35 deletions

View file

@ -429,9 +429,9 @@ class ChatAgent(ShieldRunnerMixin):
n_iter = 0
# Build a map of custom tools to their definitions for faster lookup
custom_tools = {}
for tool in self.agent_config.custom_tools:
custom_tools[tool.name] = tool
client_tools = {}
for tool in self.agent_config.client_tools:
client_tools[tool.name] = tool
while True:
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
@ -560,7 +560,7 @@ class ChatAgent(ShieldRunnerMixin):
else:
log.info(f"{str(message)}")
tool_call = message.tool_calls[0]
if tool_call.tool_name in custom_tools:
if tool_call.tool_name in client_tools:
yield message
return
@ -656,7 +656,7 @@ class ChatAgent(ShieldRunnerMixin):
async def _get_tools(self) -> List[ToolDefinition]:
ret = []
for tool in self.agent_config.custom_tools:
for tool in self.agent_config.client_tools:
params = {}
for param in tool.parameters:
params[param.name] = ToolParamDefinition(
@ -672,7 +672,7 @@ class ChatAgent(ShieldRunnerMixin):
parameters=params,
)
)
for tool_name in self.agent_config.available_tools:
for tool_name in self.agent_config.tool_names:
tool = await self.tool_groups_api.get_tool(tool_name)
if tool.built_in_type:
ret.append(ToolDefinition(tool_name=tool.built_in_type))

View file

@ -8,13 +8,13 @@ from typing import Any, Dict, List
from urllib.parse import urlparse
from llama_stack.apis.tools import (
CustomToolDef,
MCPToolGroupDef,
ToolDef,
ToolGroupDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
UserDefinedToolDef,
)
from llama_stack.providers.datatypes import ToolsProtocolPrivate
@ -53,7 +53,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
)
)
tools.append(
CustomToolDef(
UserDefinedToolDef(
name=tool.name,
description=tool.description,
parameters=parameters,

View file

@ -104,7 +104,7 @@ async def create_agent_turn_with_search_tool(
agent_config = AgentConfig(
**{
**common_params,
"available_tools": [tool_name],
"tool_names": [tool_name],
}
)

View file

@ -13,9 +13,9 @@ from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.apis.tools import (
BuiltInToolDef,
CustomToolDef,
ToolGroupInput,
ToolParameter,
UserDefinedToolDef,
UserDefinedToolGroupDef,
)
from llama_stack.distribution.datatypes import Api, Provider
@ -50,7 +50,7 @@ def tool_group_input_memory() -> ToolGroupInput:
tool_group_id="memory_group",
tool_group=UserDefinedToolGroupDef(
tools=[
CustomToolDef(
UserDefinedToolDef(
name="memory",
description="Query the memory bank",
parameters=[