mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-05 18:22:41 +00:00
do not pass memory tools to inference
This commit is contained in:
parent
d0c8dced65
commit
b46d94d87d
2 changed files with 4 additions and 11 deletions
|
@@ -487,20 +487,13 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
stop_reason = None
|
stop_reason = None
|
||||||
|
|
||||||
with tracing.span("inference") as span:
|
with tracing.span("inference") as span:
|
||||||
|
|
||||||
def is_memory_group(tool):
|
|
||||||
memory_tool_group = tool_to_group.get(MEMORY_QUERY_TOOL, None)
|
|
||||||
has_memory_tool = MEMORY_QUERY_TOOL in tool_defs
|
|
||||||
return (
|
|
||||||
has_memory_tool
|
|
||||||
and tool_to_group.get(tool.tool_name, None) != memory_tool_group
|
|
||||||
)
|
|
||||||
|
|
||||||
async for chunk in await self.inference_api.chat_completion(
|
async for chunk in await self.inference_api.chat_completion(
|
||||||
self.agent_config.model,
|
self.agent_config.model,
|
||||||
input_messages,
|
input_messages,
|
||||||
tools=[
|
tools=[
|
||||||
tool for tool in tool_defs.values() if not is_memory_group(tool)
|
tool
|
||||||
|
for tool in tool_defs.values()
|
||||||
|
if tool_to_group.get(tool.tool_name, None) != MEMORY_GROUP
|
||||||
],
|
],
|
||||||
tool_prompt_format=self.agent_config.tool_prompt_format,
|
tool_prompt_format=self.agent_config.tool_prompt_format,
|
||||||
stream=True,
|
stream=True,
|
||||||
|
|
|
@@ -60,7 +60,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||||
description="Retrieve context from memory",
|
description="Retrieve context from memory",
|
||||||
parameters=[
|
parameters=[
|
||||||
ToolParameter(
|
ToolParameter(
|
||||||
name="input_messages",
|
name="messages",
|
||||||
description="The input messages to search for",
|
description="The input messages to search for",
|
||||||
parameter_type="array",
|
parameter_type="array",
|
||||||
),
|
),
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue