mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-13 21:29:57 +00:00
3/n
This commit is contained in:
parent
01f90dfe0c
commit
cd36a77e20
3 changed files with 69 additions and 26 deletions
|
@ -297,6 +297,15 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
|
||||||
tool_config: Optional[ToolConfig] = None
|
tool_config: Optional[ToolConfig] = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AgentTurnContinueRequest(BaseModel):
|
||||||
|
agent_id: str
|
||||||
|
session_id: str
|
||||||
|
turn_id: str
|
||||||
|
tool_responses: List[ToolResponseMessage]
|
||||||
|
stream: Optional[bool] = False
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class AgentTurnResponseStreamChunk(BaseModel):
|
class AgentTurnResponseStreamChunk(BaseModel):
|
||||||
"""streamed agent turn completion response."""
|
"""streamed agent turn completion response."""
|
||||||
|
|
|
@ -23,6 +23,7 @@ from llama_stack.apis.agents import (
|
||||||
AgentConfig,
|
AgentConfig,
|
||||||
AgentToolGroup,
|
AgentToolGroup,
|
||||||
AgentToolGroupWithArgs,
|
AgentToolGroupWithArgs,
|
||||||
|
AgentTurnContinueRequest,
|
||||||
AgentTurnCreateRequest,
|
AgentTurnCreateRequest,
|
||||||
AgentTurnResponseEvent,
|
AgentTurnResponseEvent,
|
||||||
AgentTurnResponseEventType,
|
AgentTurnResponseEventType,
|
||||||
|
@ -30,7 +31,6 @@ from llama_stack.apis.agents import (
|
||||||
AgentTurnResponseStepProgressPayload,
|
AgentTurnResponseStepProgressPayload,
|
||||||
AgentTurnResponseStepStartPayload,
|
AgentTurnResponseStepStartPayload,
|
||||||
AgentTurnResponseStreamChunk,
|
AgentTurnResponseStreamChunk,
|
||||||
AgentTurnResponseTurnAwaitingInputPayload,
|
|
||||||
AgentTurnResponseTurnCompletePayload,
|
AgentTurnResponseTurnCompletePayload,
|
||||||
AgentTurnResponseTurnStartPayload,
|
AgentTurnResponseTurnStartPayload,
|
||||||
Attachment,
|
Attachment,
|
||||||
|
@ -227,25 +227,51 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||||
|
|
||||||
if output_message.tool_calls:
|
chunk = AgentTurnResponseStreamChunk(
|
||||||
chunk = AgentTurnResponseStreamChunk(
|
event=AgentTurnResponseEvent(
|
||||||
event=AgentTurnResponseEvent(
|
payload=AgentTurnResponseTurnCompletePayload(
|
||||||
payload=AgentTurnResponseTurnAwaitingInputPayload(
|
turn=turn,
|
||||||
turn=turn,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
)
|
||||||
chunk = AgentTurnResponseStreamChunk(
|
|
||||||
event=AgentTurnResponseEvent(
|
|
||||||
payload=AgentTurnResponseTurnCompletePayload(
|
|
||||||
turn=turn,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
async def continue_turn(self, request: AgentTurnContinueRequest) -> AsyncGenerator:
|
||||||
|
with tracing.span("continue_turn") as span:
|
||||||
|
span.set_attribute("agent_id", self.agent_id)
|
||||||
|
span.set_attribute("session_id", request.session_id)
|
||||||
|
span.set_attribute("turn_id", request.turn_id)
|
||||||
|
span.set_attribute("request", request.model_dump_json())
|
||||||
|
assert request.stream is True, "Non-streaming not supported"
|
||||||
|
|
||||||
|
session_info = await self.storage.get_session_info(request.session_id)
|
||||||
|
if session_info is None:
|
||||||
|
raise ValueError(f"Session {request.session_id} not found")
|
||||||
|
|
||||||
|
turns = await self.storage.get_session_turns(request.session_id)
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
if self.agent_config.instructions != "":
|
||||||
|
messages.append(SystemMessage(content=self.agent_config.instructions))
|
||||||
|
|
||||||
|
for i, turn in enumerate(turns):
|
||||||
|
messages.extend(self.turn_to_messages(turn))
|
||||||
|
|
||||||
|
messages.extend(request.messages)
|
||||||
|
|
||||||
|
# steps = []
|
||||||
|
# output_message = None
|
||||||
|
# async for chunk in self.run(
|
||||||
|
# session_id=request.session_id,
|
||||||
|
# turn_id=request.turn_id,
|
||||||
|
# input_messages=messages,
|
||||||
|
# sampling_params=self.agent_config.sampling_params,
|
||||||
|
# stream=request.stream,
|
||||||
|
# documents=request.documents,
|
||||||
|
# toolgroups_for_turn=request.toolgroups,
|
||||||
|
# ):
|
||||||
|
# if isinstance(chunk, CompletionMessage):
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self,
|
self,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
|
@ -626,7 +652,11 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
input_messages = input_messages + [message]
|
input_messages = input_messages + [message]
|
||||||
else:
|
else:
|
||||||
log.info(f"{str(message)}")
|
log.info(f"{str(message)}")
|
||||||
# 1. Start the tool execution step and progress
|
tool_call = message.tool_calls[0]
|
||||||
|
if tool_call.tool_name in client_tools:
|
||||||
|
yield message
|
||||||
|
return
|
||||||
|
|
||||||
step_id = str(uuid.uuid4())
|
step_id = str(uuid.uuid4())
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
|
@ -636,8 +666,6 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
tool_call = message.tool_calls[0]
|
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgentTurnResponseStepProgressPayload(
|
payload=AgentTurnResponseStepProgressPayload(
|
||||||
|
@ -652,12 +680,6 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# If tool is a client tool, yield CompletionMessage and return
|
|
||||||
if tool_call.tool_name in client_tools:
|
|
||||||
yield message
|
|
||||||
return
|
|
||||||
|
|
||||||
# If tool is a builtin server tool, execute it
|
|
||||||
tool_name = tool_call.tool_name
|
tool_name = tool_call.tool_name
|
||||||
if isinstance(tool_name, BuiltinTool):
|
if isinstance(tool_name, BuiltinTool):
|
||||||
tool_name = tool_name.value
|
tool_name = tool_name.value
|
||||||
|
|
|
@ -20,6 +20,7 @@ from llama_stack.apis.agents import (
|
||||||
AgentSessionCreateResponse,
|
AgentSessionCreateResponse,
|
||||||
AgentStepResponse,
|
AgentStepResponse,
|
||||||
AgentToolGroup,
|
AgentToolGroup,
|
||||||
|
AgentTurnContinueRequest,
|
||||||
AgentTurnCreateRequest,
|
AgentTurnCreateRequest,
|
||||||
Document,
|
Document,
|
||||||
Session,
|
Session,
|
||||||
|
@ -177,7 +178,18 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
tool_responses: List[ToolResponseMessage],
|
tool_responses: List[ToolResponseMessage],
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
pass
|
if stream:
|
||||||
|
return self._continue_agent_turn_streaming(request)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Non-streaming agent turns not yet implemented")
|
||||||
|
|
||||||
|
async def _continue_agent_turn_streaming(
|
||||||
|
self,
|
||||||
|
request: AgentTurnContinueRequest,
|
||||||
|
) -> AsyncGenerator:
|
||||||
|
agent = await self.get_agent(request.agent_id)
|
||||||
|
async for event in agent.continue_turn(request):
|
||||||
|
yield event
|
||||||
|
|
||||||
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
|
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
|
||||||
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
|
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue