remove attachements, move memory bank to tool metadata

This commit is contained in:
Dinesh Yeduguru 2024-12-26 15:48:52 -08:00
parent 97798c8442
commit f408fd3aca
9 changed files with 45 additions and 180 deletions

View file

@ -188,7 +188,6 @@ class ChatAgent(ShieldRunnerMixin):
session_id=request.session_id,
turn_id=turn_id,
input_messages=messages,
attachments=request.attachments or [],
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
):
@ -238,7 +237,6 @@ class ChatAgent(ShieldRunnerMixin):
session_id: str,
turn_id: str,
input_messages: List[Message],
attachments: List[Attachment],
sampling_params: SamplingParams,
stream: bool = False,
) -> AsyncGenerator:
@ -257,7 +255,7 @@ class ChatAgent(ShieldRunnerMixin):
yield res
async for res in self._run(
session_id, turn_id, input_messages, attachments, sampling_params, stream
session_id, turn_id, input_messages, sampling_params, stream
):
if isinstance(res, bool):
return
@ -350,7 +348,6 @@ class ChatAgent(ShieldRunnerMixin):
session_id: str,
turn_id: str,
input_messages: List[Message],
attachments: List[Attachment],
sampling_params: SamplingParams,
stream: bool = False,
) -> AsyncGenerator:
@ -370,7 +367,6 @@ class ChatAgent(ShieldRunnerMixin):
session_id=session_id,
turn_id=turn_id,
input_messages=input_messages,
attachments=attachments,
)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
@ -423,7 +419,10 @@ class ChatAgent(ShieldRunnerMixin):
span.set_attribute("output", result.content)
span.set_attribute("error_code", result.error_code)
span.set_attribute("error_message", result.error_message)
span.set_attribute("tool_name", tool_name)
if isinstance(tool_name, BuiltinTool):
span.set_attribute("tool_name", tool_name.value)
else:
span.set_attribute("tool_name", tool_name)
if result.error_code == 0:
last_message = input_messages[-1]
last_message.context = result.content
@ -553,9 +552,9 @@ class ChatAgent(ShieldRunnerMixin):
# TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
if len(output_attachments) > 0:
if isinstance(message.content, list):
message.content += attachments
message.content += output_attachments
else:
message.content = [message.content] + attachments
message.content = [message.content] + output_attachments
yield message
else:
log.info(f"Partial message: {str(message)}")
@ -586,10 +585,13 @@ class ChatAgent(ShieldRunnerMixin):
)
)
tool_name = tool_call.tool_name
if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value
with tracing.span(
"tool_execution",
{
"tool_name": tool_call.tool_name,
"tool_name": tool_name,
"input": message.model_dump_json(),
},
) as span:
@ -608,6 +610,7 @@ class ChatAgent(ShieldRunnerMixin):
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
step_details=ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,

View file

@ -146,14 +146,12 @@ class MetaReferenceAgentsImpl(Agents):
ToolResponseMessage,
]
],
attachments: Optional[List[Attachment]] = None,
stream: Optional[bool] = False,
) -> AsyncGenerator:
request = AgentTurnCreateRequest(
agent_id=agent_id,
session_id=session_id,
messages=messages,
attachments=attachments,
stream=True,
)
if stream: