mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 04:50:39 +00:00
better checking
This commit is contained in:
parent
e32ed65bef
commit
0560d1f4a2
1 changed files with 10 additions and 2 deletions
|
@ -125,6 +125,13 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
def turn_to_messages(self, turn: Turn) -> List[Message]:
|
def turn_to_messages(self, turn: Turn) -> List[Message]:
|
||||||
messages = []
|
messages = []
|
||||||
|
|
||||||
|
# NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
|
||||||
|
tool_call_ids = set()
|
||||||
|
for step in turn.steps:
|
||||||
|
if step.step_type == StepType.tool_execution.value:
|
||||||
|
for response in step.tool_responses:
|
||||||
|
tool_call_ids.add(response.call_id)
|
||||||
|
|
||||||
for m in turn.input_messages:
|
for m in turn.input_messages:
|
||||||
msg = m.model_copy()
|
msg = m.model_copy()
|
||||||
# We do not want to keep adding RAG context to the input messages
|
# We do not want to keep adding RAG context to the input messages
|
||||||
|
@ -133,8 +140,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
if isinstance(msg, UserMessage):
|
if isinstance(msg, UserMessage):
|
||||||
msg.context = None
|
msg.context = None
|
||||||
if isinstance(msg, ToolResponseMessage):
|
if isinstance(msg, ToolResponseMessage):
|
||||||
# NOTE: do not add ToolResponseMessage here, we'll add them in tool_execution steps
|
if msg.call_id in tool_call_ids:
|
||||||
continue
|
# NOTE: do not add ToolResponseMessage here, we'll add them in tool_execution steps
|
||||||
|
continue
|
||||||
|
|
||||||
messages.append(msg)
|
messages.append(msg)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue