From 58f9fd135b8f71b7e8d621faddd83c7d650d2c3a Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 27 Feb 2025 13:55:46 -0800 Subject: [PATCH] fix --- .../agents/meta_reference/agent_instance.py | 48 ++++++++++++++++--- 1 file changed, 42 insertions(+), 6 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index e14a35463..eb3257bce 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -17,6 +17,7 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple from urllib.parse import urlparse import httpx +from rich.pretty import pprint from llama_stack.apis.agents import ( AgentConfig, @@ -125,13 +126,17 @@ class ChatAgent(ShieldRunnerMixin): def turn_to_messages(self, turn: Turn) -> List[Message]: messages = [] - # We do not want to keep adding RAG context to the input messages - # May be this should be a parameter of the agentic instance - # that can define its behavior in a custom way for m in turn.input_messages: msg = m.model_copy() + # We do not want to keep adding RAG context to the input messages + # May be this should be a parameter of the agentic instance + # that can define its behavior in a custom way if isinstance(msg, UserMessage): msg.context = None + if isinstance(msg, ToolResponseMessage): + # NOTE: do not add ToolResponseMessage here, we'll add them in tool_execution steps + continue + messages.append(msg) for step in turn.steps: @@ -181,9 +186,20 @@ class ChatAgent(ShieldRunnerMixin): raise ValueError(f"Session {request.session_id} not found") turns = await self.storage.get_session_turns(request.session_id) + + print("!! create and execute turn turns", len(turns)) + pprint(turns) + messages = await self.get_messages_from_turns(turns) + + print("!! create and execute turn messages", len(messages)) + pprint(messages) + messages.extend(request.messages) + print("!! create and execute turn messages extended", len(messages)) + pprint(messages) + turn_id = str(uuid.uuid4()) span.set_attribute("turn_id", turn_id) start_time = datetime.now().astimezone().isoformat() @@ -265,17 +281,31 @@ class ChatAgent(ShieldRunnerMixin): raise ValueError(f"Session {request.session_id} not found") turns = await self.storage.get_session_turns(request.session_id) + if len(turns) == 0: + raise ValueError("No turns found for session") + + pprint("!! resume turn turns") + pprint(turns) messages = await self.get_messages_from_turns(turns) messages.extend(request.tool_responses) + print("!! resume turn") + pprint(messages) + + last_turn = turns[-1] + last_turn_messages = self.turn_to_messages(last_turn) last_turn_messages = [ - x for x in messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage) + x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage) ] + print("last turn messages") + pprint(last_turn_messages) + # TODO: figure out whether we should add the tool responses to the last turn messages + last_turn_messages.extend(request.tool_responses) + # get the steps from the turn id steps = [] - if len(turns) > 0: - steps = turns[-1].steps + steps = turns[-1].steps # mark tool execution step as complete # if there's no tool execution in progress step (due to storage, or tool call parsing on client), @@ -375,6 +405,9 @@ class ChatAgent(ShieldRunnerMixin): documents: Optional[List[Document]] = None, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None, ) -> AsyncGenerator: + print("!!RUN input messages") + + pprint(input_messages) # Doing async generators makes downstream code much simpler and everything amenable to # streaming. However, it also makes things complicated here because AsyncGenerators cannot # return a "final value" for the `yield from` statement. we simulate that by yielding a @@ -419,6 +452,9 @@ class ChatAgent(ShieldRunnerMixin): else: yield res + pprint("!!RUN final response") + pprint(messages) + yield final_response async def run_multiple_shields_wrapper(