This commit is contained in:
Xi Yan 2025-02-27 13:55:46 -08:00
parent 564f0e5f93
commit 58f9fd135b

View file

@ -17,6 +17,7 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
from urllib.parse import urlparse from urllib.parse import urlparse
import httpx import httpx
from rich.pretty import pprint
from llama_stack.apis.agents import ( from llama_stack.apis.agents import (
AgentConfig, AgentConfig,
@ -125,13 +126,17 @@ class ChatAgent(ShieldRunnerMixin):
def turn_to_messages(self, turn: Turn) -> List[Message]: def turn_to_messages(self, turn: Turn) -> List[Message]:
messages = [] messages = []
# We do not want to keep adding RAG context to the input messages
# May be this should be a parameter of the agentic instance
# that can define its behavior in a custom way
for m in turn.input_messages: for m in turn.input_messages:
msg = m.model_copy() msg = m.model_copy()
# We do not want to keep adding RAG context to the input messages
# May be this should be a parameter of the agentic instance
# that can define its behavior in a custom way
if isinstance(msg, UserMessage): if isinstance(msg, UserMessage):
msg.context = None msg.context = None
if isinstance(msg, ToolResponseMessage):
# NOTE: do not add ToolResponseMessage here, we'll add them in tool_execution steps
continue
messages.append(msg) messages.append(msg)
for step in turn.steps: for step in turn.steps:
@ -181,9 +186,20 @@ class ChatAgent(ShieldRunnerMixin):
raise ValueError(f"Session {request.session_id} not found") raise ValueError(f"Session {request.session_id} not found")
turns = await self.storage.get_session_turns(request.session_id) turns = await self.storage.get_session_turns(request.session_id)
print("!! create and execute turn turns", len(turns))
pprint(turns)
messages = await self.get_messages_from_turns(turns) messages = await self.get_messages_from_turns(turns)
print("!! create and execute turn messages", len(messages))
pprint(messages)
messages.extend(request.messages) messages.extend(request.messages)
print("!! create and execute turn messages extended", len(messages))
pprint(messages)
turn_id = str(uuid.uuid4()) turn_id = str(uuid.uuid4())
span.set_attribute("turn_id", turn_id) span.set_attribute("turn_id", turn_id)
start_time = datetime.now().astimezone().isoformat() start_time = datetime.now().astimezone().isoformat()
@ -265,17 +281,31 @@ class ChatAgent(ShieldRunnerMixin):
raise ValueError(f"Session {request.session_id} not found") raise ValueError(f"Session {request.session_id} not found")
turns = await self.storage.get_session_turns(request.session_id) turns = await self.storage.get_session_turns(request.session_id)
if len(turns) == 0:
raise ValueError("No turns found for session")
pprint("!! resume turn turns")
pprint(turns)
messages = await self.get_messages_from_turns(turns) messages = await self.get_messages_from_turns(turns)
messages.extend(request.tool_responses) messages.extend(request.tool_responses)
print("!! resume turn")
pprint(messages)
last_turn = turns[-1]
last_turn_messages = self.turn_to_messages(last_turn)
last_turn_messages = [ last_turn_messages = [
x for x in messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage) x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
] ]
print("last turn messages")
pprint(last_turn_messages)
# TODO: figure out whether we should add the tool responses to the last turn messages
last_turn_messages.extend(request.tool_responses)
# get the steps from the turn id # get the steps from the turn id
steps = [] steps = []
if len(turns) > 0: steps = turns[-1].steps
steps = turns[-1].steps
# mark tool execution step as complete # mark tool execution step as complete
# if there's no tool execution in progress step (due to storage, or tool call parsing on client), # if there's no tool execution in progress step (due to storage, or tool call parsing on client),
@ -375,6 +405,9 @@ class ChatAgent(ShieldRunnerMixin):
documents: Optional[List[Document]] = None, documents: Optional[List[Document]] = None,
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
print("!!RUN input messages")
pprint(input_messages)
# Doing async generators makes downstream code much simpler and everything amenable to # Doing async generators makes downstream code much simpler and everything amenable to
# streaming. However, it also makes things complicated here because AsyncGenerators cannot # streaming. However, it also makes things complicated here because AsyncGenerators cannot
# return a "final value" for the `yield from` statement. we simulate that by yielding a # return a "final value" for the `yield from` statement. we simulate that by yielding a
@ -419,6 +452,9 @@ class ChatAgent(ShieldRunnerMixin):
else: else:
yield res yield res
pprint("!!RUN final response")
pprint(messages)
yield final_response yield final_response
async def run_multiple_shields_wrapper( async def run_multiple_shields_wrapper(