do not pass memory tools to inference

Dinesh Yeduguru 2025-01-08 18:53:32 -08:00
parent d0c8dced65
commit b46d94d87d
2 changed files with 4 additions and 11 deletions


@@ -487,20 +487,13 @@ class ChatAgent(ShieldRunnerMixin):
         stop_reason = None
 
         with tracing.span("inference") as span:
-
-            def is_memory_group(tool):
-                memory_tool_group = tool_to_group.get(MEMORY_QUERY_TOOL, None)
-                has_memory_tool = MEMORY_QUERY_TOOL in tool_defs
-                return (
-                    has_memory_tool
-                    and tool_to_group.get(tool.tool_name, None) != memory_tool_group
-                )
-
             async for chunk in await self.inference_api.chat_completion(
                 self.agent_config.model,
                 input_messages,
                 tools=[
-                    tool for tool in tool_defs.values() if not is_memory_group(tool)
+                    tool
+                    for tool in tool_defs.values()
+                    if tool_to_group.get(tool.tool_name, None) != MEMORY_GROUP
                 ],
                 tool_prompt_format=self.agent_config.tool_prompt_format,
                 stream=True,
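The effect of the new comprehension is to drop every tool whose group matches the MEMORY_GROUP constant before the definitions are handed to chat_completion. Below is a minimal, self-contained sketch of that filtering step; the group names and tool names are illustrative stand-ins, and plain strings take the place of the real ToolDefinition objects (whose tool_name field is what the actual code looks up in tool_to_group):

# Illustrative stand-ins only; the real code compares
# tool_to_group.get(tool.tool_name, None) against MEMORY_GROUP.
MEMORY_GROUP = "memory"

tool_to_group = {
    "query_memory": "memory",   # memory tool: excluded from inference
    "web_search": "search",     # ordinary tool: still passed to inference
}
tool_defs = {name: name for name in tool_to_group}  # stand-in for ToolDefinition objects

tools_for_inference = [
    tool
    for tool in tool_defs.values()
    if tool_to_group.get(tool, None) != MEMORY_GROUP
]

assert tools_for_inference == ["web_search"]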


@@ -60,7 +60,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
                 description="Retrieve context from memory",
                 parameters=[
                     ToolParameter(
-                        name="input_messages",
+                        name="messages",
                         description="The input messages to search for",
                         parameter_type="array",
                     ),
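For reference, a sketch of the resulting memory query tool definition after the rename, assembled from the fields visible in the diff. The import path and the concrete value of MEMORY_QUERY_TOOL are assumptions for illustration, not taken from this commit:

# Sketch of the tool definition after the parameter rename.
# Import path and MEMORY_QUERY_TOOL value are assumed for illustration.
from llama_stack.apis.tools import ToolDef, ToolParameter

MEMORY_QUERY_TOOL = "query_memory"  # assumed value of the constant

memory_query_tool_def = ToolDef(
    name=MEMORY_QUERY_TOOL,
    description="Retrieve context from memory",
    parameters=[
        ToolParameter(
            name="messages",  # renamed from "input_messages" in this commit
            description="The input messages to search for",
            parameter_type="array",
        ),
    ],
)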