mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 04:50:39 +00:00
fix
This commit is contained in:
parent
564f0e5f93
commit
58f9fd135b
1 changed files with 42 additions and 6 deletions
|
@ -17,6 +17,7 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
from rich.pretty import pprint
|
||||||
|
|
||||||
from llama_stack.apis.agents import (
|
from llama_stack.apis.agents import (
|
||||||
AgentConfig,
|
AgentConfig,
|
||||||
|
@ -125,13 +126,17 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
def turn_to_messages(self, turn: Turn) -> List[Message]:
|
def turn_to_messages(self, turn: Turn) -> List[Message]:
|
||||||
messages = []
|
messages = []
|
||||||
|
|
||||||
# We do not want to keep adding RAG context to the input messages
|
|
||||||
# May be this should be a parameter of the agentic instance
|
|
||||||
# that can define its behavior in a custom way
|
|
||||||
for m in turn.input_messages:
|
for m in turn.input_messages:
|
||||||
msg = m.model_copy()
|
msg = m.model_copy()
|
||||||
|
# We do not want to keep adding RAG context to the input messages
|
||||||
|
# May be this should be a parameter of the agentic instance
|
||||||
|
# that can define its behavior in a custom way
|
||||||
if isinstance(msg, UserMessage):
|
if isinstance(msg, UserMessage):
|
||||||
msg.context = None
|
msg.context = None
|
||||||
|
if isinstance(msg, ToolResponseMessage):
|
||||||
|
# NOTE: do not add ToolResponseMessage here, we'll add them in tool_execution steps
|
||||||
|
continue
|
||||||
|
|
||||||
messages.append(msg)
|
messages.append(msg)
|
||||||
|
|
||||||
for step in turn.steps:
|
for step in turn.steps:
|
||||||
|
@ -181,9 +186,20 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
raise ValueError(f"Session {request.session_id} not found")
|
raise ValueError(f"Session {request.session_id} not found")
|
||||||
|
|
||||||
turns = await self.storage.get_session_turns(request.session_id)
|
turns = await self.storage.get_session_turns(request.session_id)
|
||||||
|
|
||||||
|
print("!! create and execute turn turns", len(turns))
|
||||||
|
pprint(turns)
|
||||||
|
|
||||||
messages = await self.get_messages_from_turns(turns)
|
messages = await self.get_messages_from_turns(turns)
|
||||||
|
|
||||||
|
print("!! create and execute turn messages", len(messages))
|
||||||
|
pprint(messages)
|
||||||
|
|
||||||
messages.extend(request.messages)
|
messages.extend(request.messages)
|
||||||
|
|
||||||
|
print("!! create and execute turn messages extended", len(messages))
|
||||||
|
pprint(messages)
|
||||||
|
|
||||||
turn_id = str(uuid.uuid4())
|
turn_id = str(uuid.uuid4())
|
||||||
span.set_attribute("turn_id", turn_id)
|
span.set_attribute("turn_id", turn_id)
|
||||||
start_time = datetime.now().astimezone().isoformat()
|
start_time = datetime.now().astimezone().isoformat()
|
||||||
|
@ -265,17 +281,31 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
raise ValueError(f"Session {request.session_id} not found")
|
raise ValueError(f"Session {request.session_id} not found")
|
||||||
|
|
||||||
turns = await self.storage.get_session_turns(request.session_id)
|
turns = await self.storage.get_session_turns(request.session_id)
|
||||||
|
if len(turns) == 0:
|
||||||
|
raise ValueError("No turns found for session")
|
||||||
|
|
||||||
|
pprint("!! resume turn turns")
|
||||||
|
pprint(turns)
|
||||||
messages = await self.get_messages_from_turns(turns)
|
messages = await self.get_messages_from_turns(turns)
|
||||||
messages.extend(request.tool_responses)
|
messages.extend(request.tool_responses)
|
||||||
|
|
||||||
|
print("!! resume turn")
|
||||||
|
pprint(messages)
|
||||||
|
|
||||||
|
last_turn = turns[-1]
|
||||||
|
last_turn_messages = self.turn_to_messages(last_turn)
|
||||||
last_turn_messages = [
|
last_turn_messages = [
|
||||||
x for x in messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
|
x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
print("last turn messages")
|
||||||
|
pprint(last_turn_messages)
|
||||||
|
# TODO: figure out whether we should add the tool responses to the last turn messages
|
||||||
|
last_turn_messages.extend(request.tool_responses)
|
||||||
|
|
||||||
# get the steps from the turn id
|
# get the steps from the turn id
|
||||||
steps = []
|
steps = []
|
||||||
if len(turns) > 0:
|
steps = turns[-1].steps
|
||||||
steps = turns[-1].steps
|
|
||||||
|
|
||||||
# mark tool execution step as complete
|
# mark tool execution step as complete
|
||||||
# if there's no tool execution in progress step (due to storage, or tool call parsing on client),
|
# if there's no tool execution in progress step (due to storage, or tool call parsing on client),
|
||||||
|
@ -375,6 +405,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
documents: Optional[List[Document]] = None,
|
documents: Optional[List[Document]] = None,
|
||||||
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
|
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
|
print("!!RUN input messages")
|
||||||
|
|
||||||
|
pprint(input_messages)
|
||||||
# Doing async generators makes downstream code much simpler and everything amenable to
|
# Doing async generators makes downstream code much simpler and everything amenable to
|
||||||
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
|
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
|
||||||
# return a "final value" for the `yield from` statement. we simulate that by yielding a
|
# return a "final value" for the `yield from` statement. we simulate that by yielding a
|
||||||
|
@ -419,6 +452,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
else:
|
else:
|
||||||
yield res
|
yield res
|
||||||
|
|
||||||
|
pprint("!!RUN final response")
|
||||||
|
pprint(messages)
|
||||||
|
|
||||||
yield final_response
|
yield final_response
|
||||||
|
|
||||||
async def run_multiple_shields_wrapper(
|
async def run_multiple_shields_wrapper(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue