Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-12 13:00:39 +00:00
3/n
Commit cd36a77e20 (parent 01f90dfe0c). 3 changed files with 69 additions and 26 deletions.
@@ -297,6 +297,15 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
     tool_config: Optional[ToolConfig] = None


+@json_schema_type
+class AgentTurnContinueRequest(BaseModel):
+    agent_id: str
+    session_id: str
+    turn_id: str
+    tool_responses: List[ToolResponseMessage]
+    stream: Optional[bool] = False
+
+
 @json_schema_type
 class AgentTurnResponseStreamChunk(BaseModel):
     """streamed agent turn completion response."""
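The hunk above adds AgentTurnContinueRequest to the agents API schema: the request a client sends to resume a turn that paused for client-side tool execution. For orientation, a caller might construct it as below. The IDs are made up, and the exact ToolResponseMessage fields (call_id, tool_name, content) are an assumption about the inference API of this vintage rather than something shown in the diff:

from llama_stack.apis.agents import AgentTurnContinueRequest
from llama_stack.apis.inference import ToolResponseMessage

# Illustrative values only; call_id must echo the tool call the server
# emitted before pausing the turn.
request = AgentTurnContinueRequest(
    agent_id="agent-123",
    session_id="sess-456",
    turn_id="turn-789",
    tool_responses=[
        ToolResponseMessage(
            call_id="call-abc",
            tool_name="get_weather",
            content='{"temp_f": 72}',
        )
    ],
    stream=True,  # the meta-reference implementation below asserts stream is True
)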
@@ -23,6 +23,7 @@ from llama_stack.apis.agents import (
     AgentConfig,
     AgentToolGroup,
     AgentToolGroupWithArgs,
+    AgentTurnContinueRequest,
     AgentTurnCreateRequest,
     AgentTurnResponseEvent,
     AgentTurnResponseEventType,
@@ -30,7 +31,6 @@ from llama_stack.apis.agents import (
     AgentTurnResponseStepProgressPayload,
     AgentTurnResponseStepStartPayload,
     AgentTurnResponseStreamChunk,
     AgentTurnResponseTurnAwaitingInputPayload,
     AgentTurnResponseTurnCompletePayload,
     AgentTurnResponseTurnStartPayload,
     Attachment,
@@ -227,25 +227,51 @@ class ChatAgent(ShieldRunnerMixin):
             )
         await self.storage.add_turn_to_session(request.session_id, turn)

-        chunk = AgentTurnResponseStreamChunk(
-            event=AgentTurnResponseEvent(
-                payload=AgentTurnResponseTurnCompletePayload(
-                    turn=turn,
-                )
-            )
-        )
+        if output_message.tool_calls:
+            chunk = AgentTurnResponseStreamChunk(
+                event=AgentTurnResponseEvent(
+                    payload=AgentTurnResponseTurnAwaitingInputPayload(
+                        turn=turn,
+                    )
+                )
+            )
+        else:
+            chunk = AgentTurnResponseStreamChunk(
+                event=AgentTurnResponseEvent(
+                    payload=AgentTurnResponseTurnCompletePayload(
+                        turn=turn,
+                    )
+                )
+            )
+
         yield chunk

+    async def continue_turn(self, request: AgentTurnContinueRequest) -> AsyncGenerator:
+        with tracing.span("continue_turn") as span:
+            span.set_attribute("agent_id", self.agent_id)
+            span.set_attribute("session_id", request.session_id)
+            span.set_attribute("turn_id", request.turn_id)
+            span.set_attribute("request", request.model_dump_json())
+            assert request.stream is True, "Non-streaming not supported"
+
+            session_info = await self.storage.get_session_info(request.session_id)
+            if session_info is None:
+                raise ValueError(f"Session {request.session_id} not found")
+
+            turns = await self.storage.get_session_turns(request.session_id)
+
+            messages = []
+            if self.agent_config.instructions != "":
+                messages.append(SystemMessage(content=self.agent_config.instructions))
+
+            for i, turn in enumerate(turns):
+                messages.extend(self.turn_to_messages(turn))
+
+            messages.extend(request.messages)
+
+            # steps = []
+            # output_message = None
+            # async for chunk in self.run(
+            #     session_id=request.session_id,
+            #     turn_id=request.turn_id,
+            #     input_messages=messages,
+            #     sampling_params=self.agent_config.sampling_params,
+            #     stream=request.stream,
+            #     documents=request.documents,
+            #     toolgroups_for_turn=request.toolgroups,
+            # ):
+            #     if isinstance(chunk, CompletionMessage):
+
     async def run(
         self,
         session_id: str,
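In this WIP commit ("3/n"), continue_turn rebuilds the conversation from storage but stops short of actually resuming inference: the driver loop is still commented out, and it reads request.messages even though the new schema defines tool_responses. One plausible completion, mirroring the streaming loop that the turn-creation path in this same file already uses, is sketched below. It is an assumption, not the code that eventually shipped; documents and toolgroups are dropped because AgentTurnContinueRequest does not carry them:

            # Sketch only: feed the client's tool results back into the
            # rebuilt history, then drive self.run() as turn creation does.
            for m in request.tool_responses:
                messages.append(m)

            steps = []
            output_message = None
            async for chunk in self.run(
                session_id=request.session_id,
                turn_id=request.turn_id,
                input_messages=messages,
                sampling_params=self.agent_config.sampling_params,
                stream=request.stream,
            ):
                if isinstance(chunk, CompletionMessage):
                    output_message = chunk
                    continue
                event = chunk.event
                if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
                    steps.append(event.payload.step_details)
                yield chunk
            # steps/output_message would then be used to assemble and persist
            # the resumed Turn, as the create path does after its loop.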
@@ -626,7 +652,11 @@ class ChatAgent(ShieldRunnerMixin):
             input_messages = input_messages + [message]
         else:
             log.info(f"{str(message)}")
             # 1. Start the tool execution step and progress
+            tool_call = message.tool_calls[0]
+            if tool_call.tool_name in client_tools:
+                yield message
+                return
             step_id = str(uuid.uuid4())
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
@@ -636,8 +666,6 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                 )
             )
-
-            tool_call = message.tool_calls[0]
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseStepProgressPayload(
@@ -652,12 +680,6 @@ class ChatAgent(ShieldRunnerMixin):
                 )
             )

-            # If tool is a client tool, yield CompletionMessage and return
-            if tool_call.tool_name in client_tools:
-                yield message
-                return
-
             # If tool is a builtin server tool, execute it
             tool_name = tool_call.tool_name
             if isinstance(tool_name, BuiltinTool):
                 tool_name = tool_name.value
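Net effect of these ChatAgent changes: the client-tool check moves ahead of the step-start event, so when the model calls a tool in client_tools the server yields the raw CompletionMessage and returns instead of executing it, and the enclosing turn surfaces as awaiting input rather than complete. A hypothetical consumer loop is sketched below; client, create_agent_turn, continue_agent_turn, and run_local_tool are illustrative names for whatever SDK surface eventually wraps these endpoints, and the "turn_awaiting_input" event-type string is inferred from the payload class name:

async def drive_turn(client, agent_id, session_id, messages):
    # Start a turn and watch for the new awaiting-input signal.
    turn = None
    async for chunk in client.create_agent_turn(
        agent_id=agent_id, session_id=session_id, messages=messages, stream=True
    ):
        payload = chunk.event.payload
        if payload.event_type == "turn_awaiting_input":
            turn = payload.turn  # server paused: a client tool must run

    while turn is not None:
        # The paused turn's last message carries the client tool call.
        call = turn.output_message.tool_calls[0]
        response = await run_local_tool(call)  # app-defined tool execution
        resumed = None
        async for chunk in client.continue_agent_turn(
            agent_id=agent_id,
            session_id=session_id,
            turn_id=turn.turn_id,
            tool_responses=[response],
            stream=True,
        ):
            payload = chunk.event.payload
            if payload.event_type == "turn_awaiting_input":
                resumed = payload.turn  # another client-tool round
        turn = resumed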
@@ -20,6 +20,7 @@ from llama_stack.apis.agents import (
     AgentSessionCreateResponse,
     AgentStepResponse,
     AgentToolGroup,
+    AgentTurnContinueRequest,
     AgentTurnCreateRequest,
     Document,
     Session,
@@ -177,7 +178,18 @@ class MetaReferenceAgentsImpl(Agents):
         tool_responses: List[ToolResponseMessage],
         stream: Optional[bool] = False,
     ) -> AsyncGenerator:
-        pass
+        if stream:
+            return self._continue_agent_turn_streaming(request)
+        else:
+            raise NotImplementedError("Non-streaming agent turns not yet implemented")
+
+    async def _continue_agent_turn_streaming(
+        self,
+        request: AgentTurnContinueRequest,
+    ) -> AsyncGenerator:
+        agent = await self.get_agent(request.agent_id)
+        async for event in agent.continue_turn(request):
+            yield event
+
     async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
         turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
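One loose end in this hunk: the endpoint dispatches on a request variable that is never built from the method's parameters in the lines shown (either elided diff context or a gap in this WIP commit). Under that reading, the missing glue would be roughly the following, using only the parameters visible in the signature:

        # Hypothetical: assemble the request before dispatching to the
        # streaming path; this construction is not shown in the diff itself.
        request = AgentTurnContinueRequest(
            agent_id=agent_id,
            session_id=session_id,
            turn_id=turn_id,
            tool_responses=tool_responses,
            stream=stream,
        )
        if stream:
            return self._continue_agent_turn_streaming(request)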