mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 17:29:01 +00:00
do not pass memory tools to inference
This commit is contained in:
parent
d0c8dced65
commit
b46d94d87d
2 changed files with 4 additions and 11 deletions
|
@@ -487,20 +487,13 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
stop_reason = None
|
||||
|
||||
with tracing.span("inference") as span:
|
||||
|
||||
def is_memory_group(tool):
|
||||
memory_tool_group = tool_to_group.get(MEMORY_QUERY_TOOL, None)
|
||||
has_memory_tool = MEMORY_QUERY_TOOL in tool_defs
|
||||
return (
|
||||
has_memory_tool
|
||||
and tool_to_group.get(tool.tool_name, None) != memory_tool_group
|
||||
)
|
||||
|
||||
async for chunk in await self.inference_api.chat_completion(
|
||||
self.agent_config.model,
|
||||
input_messages,
|
||||
tools=[
|
||||
tool for tool in tool_defs.values() if not is_memory_group(tool)
|
||||
tool
|
||||
for tool in tool_defs.values()
|
||||
if tool_to_group.get(tool.tool_name, None) != MEMORY_GROUP
|
||||
],
|
||||
tool_prompt_format=self.agent_config.tool_prompt_format,
|
||||
stream=True,
|
||||
|
|
|
@@ -60,7 +60,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
description="Retrieve context from memory",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="input_messages",
|
||||
name="messages",
|
||||
description="The input messages to search for",
|
||||
parameter_type="array",
|
||||
),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue