wip refactor

This commit is contained in:
Xi Yan 2025-03-04 13:38:04 -08:00
parent e9a37bad63
commit e8f7fa1ce1

View file

@ -12,7 +12,7 @@ import secrets
import string
import uuid
from datetime import datetime
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse
import httpx
@ -31,7 +31,6 @@ from llama_stack.apis.agents import (
AgentTurnResponseStreamChunk,
AgentTurnResponseTurnAwaitingInputPayload,
AgentTurnResponseTurnCompletePayload,
AgentTurnResponseTurnStartPayload,
AgentTurnResumeRequest,
Attachment,
Document,
@ -184,26 +183,34 @@ class ChatAgent(ShieldRunnerMixin):
span.set_attribute("session_id", request.session_id)
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("request", request.model_dump_json())
turn_id = str(uuid.uuid4())
span.set_attribute("turn_id", turn_id)
assert request.stream is True, "Non-streaming not supported"
async for chunk in self._run_turn(request, turn_id):
yield chunk
async def _run_turn(
self,
request: Union[AgentTurnCreateRequest, AgentTurnResumeRequest],
turn_id: Optional[str] = None,
) -> AsyncGenerator:
is_resume = isinstance(request, AgentTurnResumeRequest)
session_info = await self.storage.get_session_info(request.session_id)
if session_info is None:
raise ValueError(f"Session {request.session_id} not found")
turns = await self.storage.get_session_turns(request.session_id)
messages = await self.get_messages_from_turns(turns)
messages.extend(request.messages)
if is_resume and len(turns) == 0:
raise ValueError("No turns found for session")
turn_id = str(uuid.uuid4())
span.set_attribute("turn_id", turn_id)
messages = await self.get_messages_from_turns(turns)
if is_resume:
messages.extend(request.tool_responses)
turn_id = request.turn_id
start_time = turns[-1].started_at
else:
messages.extend(request.messages)
start_time = datetime.now().astimezone().isoformat()
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnStartPayload(
turn_id=turn_id,
)
)
)
steps = []
output_message = None
@ -213,14 +220,10 @@ class ChatAgent(ShieldRunnerMixin):
input_messages=messages,
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
documents=request.documents,
toolgroups_for_turn=request.toolgroups,
documents=request.documents if not is_resume else None,
toolgroups_for_turn=request.toolgroups if not is_resume else None,
):
if isinstance(chunk, CompletionMessage):
logcat.info(
"agents",
f"returning result from the agent turn: {chunk}",
)
output_message = chunk
continue