wip refactor

This commit is contained in:
Xi Yan 2025-03-04 13:38:04 -08:00
parent e9a37bad63
commit e8f7fa1ce1

View file

@ -12,7 +12,7 @@ import secrets
import string import string
import uuid import uuid
from datetime import datetime from datetime import datetime
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse from urllib.parse import urlparse
import httpx import httpx
@ -31,7 +31,6 @@ from llama_stack.apis.agents import (
AgentTurnResponseStreamChunk, AgentTurnResponseStreamChunk,
AgentTurnResponseTurnAwaitingInputPayload, AgentTurnResponseTurnAwaitingInputPayload,
AgentTurnResponseTurnCompletePayload, AgentTurnResponseTurnCompletePayload,
AgentTurnResponseTurnStartPayload,
AgentTurnResumeRequest, AgentTurnResumeRequest,
Attachment, Attachment,
Document, Document,
@ -184,84 +183,88 @@ class ChatAgent(ShieldRunnerMixin):
span.set_attribute("session_id", request.session_id) span.set_attribute("session_id", request.session_id)
span.set_attribute("agent_id", self.agent_id) span.set_attribute("agent_id", self.agent_id)
span.set_attribute("request", request.model_dump_json()) span.set_attribute("request", request.model_dump_json())
assert request.stream is True, "Non-streaming not supported"
session_info = await self.storage.get_session_info(request.session_id)
if session_info is None:
raise ValueError(f"Session {request.session_id} not found")
turns = await self.storage.get_session_turns(request.session_id)
messages = await self.get_messages_from_turns(turns)
messages.extend(request.messages)
turn_id = str(uuid.uuid4()) turn_id = str(uuid.uuid4())
span.set_attribute("turn_id", turn_id) span.set_attribute("turn_id", turn_id)
start_time = datetime.now().astimezone().isoformat() assert request.stream is True, "Non-streaming not supported"
yield AgentTurnResponseStreamChunk( async for chunk in self._run_turn(request, turn_id):
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnStartPayload(
turn_id=turn_id,
)
)
)
steps = []
output_message = None
async for chunk in self.run(
session_id=request.session_id,
turn_id=turn_id,
input_messages=messages,
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
documents=request.documents,
toolgroups_for_turn=request.toolgroups,
):
if isinstance(chunk, CompletionMessage):
logcat.info(
"agents",
f"returning result from the agent turn: {chunk}",
)
output_message = chunk
continue
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
event = chunk.event
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
steps.append(event.payload.step_details)
yield chunk yield chunk
async def _run_turn(
    self,
    request: Union[AgentTurnCreateRequest, AgentTurnResumeRequest],
    turn_id: Optional[str] = None,
) -> AsyncGenerator:
    """Run one agent turn, new or resumed, and stream its response chunks.

    Shared implementation behind turn creation and turn resumption. Loads the
    session history, appends the request's new messages (user messages for a
    create request, tool responses for a resume request), drives ``self.run``
    and, once it finishes, persists the resulting ``Turn`` and emits a final
    "awaiting input" or "turn complete" chunk.

    Args:
        request: Either a create request (carries ``messages``, ``documents``,
            ``toolgroups``) or a resume request (carries ``turn_id`` and
            ``tool_responses``).
        turn_id: Identifier for a newly created turn. Ignored on resume,
            where ``request.turn_id`` is used instead. If omitted for a
            create request, a fresh UUID is generated.

    Yields:
        Each ``AgentTurnResponseStreamChunk`` produced by ``self.run``,
        followed by one final chunk wrapping the persisted ``Turn``.

    Raises:
        ValueError: If the session does not exist, or a resume is requested
            for a session that has no turns.
    """
    is_resume = isinstance(request, AgentTurnResumeRequest)

    session_info = await self.storage.get_session_info(request.session_id)
    if session_info is None:
        raise ValueError(f"Session {request.session_id} not found")

    turns = await self.storage.get_session_turns(request.session_id)
    if is_resume and not turns:
        raise ValueError("No turns found for session")

    messages = await self.get_messages_from_turns(turns)
    if is_resume:
        # NOTE(review): assumes the most recent stored turn is the one being
        # resumed; request.turn_id is not checked against it - confirm.
        last_turn = turns[-1]
        messages.extend(request.tool_responses)
        turn_id = request.turn_id
        start_time = last_turn.started_at
        # A resume request is read via `tool_responses`, not `messages`, so
        # the turn's original inputs are taken from the stored turn.
        input_messages = last_turn.input_messages
    else:
        messages.extend(request.messages)
        start_time = datetime.now().astimezone().isoformat()
        input_messages = request.messages
        if turn_id is None:
            # A new turn still needs an identifier if the caller gave none.
            turn_id = str(uuid.uuid4())

    steps = []
    output_message = None
    async for chunk in self.run(
        session_id=request.session_id,
        turn_id=turn_id,
        input_messages=messages,
        sampling_params=self.agent_config.sampling_params,
        stream=request.stream,
        # Only read these from a create request; pass None on resume.
        documents=request.documents if not is_resume else None,
        toolgroups_for_turn=request.toolgroups if not is_resume else None,
    ):
        if isinstance(chunk, CompletionMessage):
            # The final model message is captured for the Turn record and is
            # not forwarded to the caller as a stream chunk.
            output_message = chunk
            continue

        assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
        event = chunk.event
        if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
            steps.append(event.payload.step_details)

        yield chunk

    assert output_message is not None

    turn = Turn(
        turn_id=turn_id,
        session_id=request.session_id,
        input_messages=input_messages,
        output_message=output_message,
        started_at=start_time,
        completed_at=datetime.now().astimezone().isoformat(),
        steps=steps,
    )
    await self.storage.add_turn_to_session(request.session_id, turn)

    # Pending tool calls mean the turn is waiting on the client; otherwise
    # the turn is complete.
    if output_message.tool_calls:
        chunk = AgentTurnResponseStreamChunk(
            event=AgentTurnResponseEvent(
                payload=AgentTurnResponseTurnAwaitingInputPayload(
                    turn=turn,
                )
            )
        )
    else:
        chunk = AgentTurnResponseStreamChunk(
            event=AgentTurnResponseEvent(
                payload=AgentTurnResponseTurnCompletePayload(
                    turn=turn,
                )
            )
        )

    yield chunk

async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator: async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
with tracing.span("resume_turn") as span: with tracing.span("resume_turn") as span:
span.set_attribute("agent_id", self.agent_id) span.set_attribute("agent_id", self.agent_id)